rm last uses of param/data

Mike Innes 2019-08-19 15:09:32 +01:00
parent a76e4d128b
commit 9590aa63e3
7 changed files with 27 additions and 28 deletions
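
For context: `param` wrapped an array as Tracker's `TrackedArray`, and `data` unwrapped it again. After the move to Zygote both are unnecessary, since gradients are taken of plain arrays. A minimal before/after sketch:

    using Flux

    # Tracker era (removed by this commit):
    #   x = param(randn(5, 5))   # wrap as a TrackedArray
    #   Flux.data(x)             # unwrap back to a plain Array
    # Zygote era: plain arrays differentiate directly.
    x = randn(5, 5)
    g = gradient(x -> sum(x .^ 2), x)   # no wrappers needed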

View File

@@ -1,6 +1,5 @@
 using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
-import ..Flux: data
 using LinearAlgebra
 mutable struct DropoutDesc
@@ -197,4 +196,4 @@ end
   BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
 @adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
-  batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
+  batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing)
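
The custom adjoint above no longer unwraps tracked inputs or guards the pullback with `nobacksies`; it just pairs the forward call with `∇batchnorm`. A toy sketch of the same Zygote `@adjoint` pattern (`myscale` is a hypothetical stand-in, not a Flux API):

    using Zygote: @adjoint

    myscale(g, x) = g .* x
    # Return the primal result plus a pullback mapping Δ to a gradient
    # per argument, mirroring the batchnorm adjoint above.
    @adjoint myscale(g, x) = myscale(g, x), Δ -> (Δ .* x, Δ .* g)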

View File

@@ -242,9 +242,9 @@ CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
 function copyparams!(m::CuRNNs, d::RNNDesc)
   Wi, Wh = d.weights
-  copy_transpose!(Wi, Flux.data(m.Wi))
-  copy_transpose!(Wh, Flux.data(m.Wh))
-  copy_transpose!(d.bias, Flux.data(m.b))
+  copy_transpose!(Wi, m.Wi)
+  copy_transpose!(Wh, m.Wh)
+  copy_transpose!(d.bias, m.b)
   return
 end
@@ -301,7 +301,7 @@ for RNN in (CuRNN, CuGRU)
 end
 @adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
+  reserve, result = forwardTrain(desc(m), x, h, c)
   result, function (Δ)
     y, ho = result
     dy, dho, dco = Δ
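
The removed `data.((x, h, c))...` idiom broadcast `data` over a tuple and splatted the unwrapped results into the call; with plain arrays the arguments pass straight through. The mechanics, with `identity` standing in for `data` (the toy `f` is illustrative only):

    f(a, b, c) = a .+ b .+ c
    args = (ones(2), ones(2), ones(2))
    f(identity.(args)...) == f(args...)   # broadcast over a tuple, then splat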

View File

@@ -8,11 +8,11 @@ using Zygote
 CuArrays.allowscalar(false)
-x = param(randn(5, 5))
+x = randn(5, 5)
 cx = gpu(x)
 @test cx isa CuArray
-@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
+@test Flux.onecold(gpu([1.0, 2.0, 3.0])) == 3
 x = Flux.onehotbatch([1, 2, 3], 1:3)
 cx = gpu(x)
@@ -29,7 +29,7 @@ x = [1,2,3]
 cx = gpu(x)
 @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
-xs = param(rand(5,5))
+xs = rand(5, 5)
 ys = Flux.onehotbatch(1:5,1:5)
 @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
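
`onecold` is the inverse of `onehot`: it returns the index of the largest entry, or the corresponding label when labels are given, which is why `onecold([1., 2., 3.]) == 3` above. For example:

    using Flux

    Flux.onecold([0.1, 0.9, 0.2])                  # -> 2
    Flux.onecold([0.1, 0.9, 0.2], [:a, :b, :c])    # -> :b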

View File

@@ -12,7 +12,7 @@ trainmode(f, x...) = forward(f, x...)[1]
 y = trainmode(m, x)
 cy = trainmode(cm, cx)
-@test cpu(data(cy)) ≈ data(y)
+@test cpu(cy) ≈ y
 g = gradient(()->sum(m(x)), params(m))
 cg = gradient(()->sum(cm(cx)), params(cm))
@@ -32,7 +32,7 @@ trainmode(f, x...) = forward(f, x...)[1]
 @test cy isa CuArray{Float32,2}
-@test cpu(data(cy)) ≈ data(y)
+@test cpu(cy) ≈ y
 g = gradient(()->sum(m(x)), params(m))
 cg = gradient(()->sum(cm(cx)), params(cm))
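
The `trainmode(f, x...) = forward(f, x...)[1]` helper works because Zygote's `forward` (later renamed `pullback`) returns `(y, back)` and runs `f` inside the differentiation context, where `Flux.istraining()` is true, so normalisation layers behave as in training. A sketch:

    using Flux, Zygote

    m = BatchNorm(2)
    x = rand(Float32, 2, 4)
    y, back = Zygote.forward(m, x)   # y is the training-mode output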

View File

@@ -8,8 +8,8 @@ using Flux, CuArrays, Test
 Flux.reset!(rnn)
 Flux.reset!(curnn)
 x = batch_size == 1 ?
-  param(rand(10)) :
-  param(rand(10,batch_size))
+  rand(10) :
+  rand(10, batch_size)
 cux = gpu(x)
 y = (rnn(x); rnn(x))
 cuy = (curnn(cux); curnn(cux))
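
The test evaluates `rnn(x)` twice because recurrent layers are stateful: each call advances the hidden state, and `Flux.reset!` restores the initial one. A sketch:

    using Flux

    rnn = RNN(10, 5)
    x = rand(Float32, 10)
    Flux.reset!(rnn)   # back to the initial hidden state
    y1 = rnn(x)
    y2 = rnn(x)        # computed from the updated state, so generally differs from y1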

View File

@@ -27,7 +27,7 @@ end
 m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
 m.weight[:] .= 1.0
 m.bias[:] .= 0.0
-y_hat = Flux.data(m(r))[:,:,1,1]
+y_hat = m(r)[:,:,1,1]
 @test size(y_hat) == (27, 29)
 @test y_hat[1, 1] ≈ 6.0
 @test y_hat[2, 2] ≈ 9.0
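
The expected sizes and values follow from stride-1 convolution arithmetic, assuming `r` is a 28×28×1×1 array of ones (consistent with the expectations above): each spatial dim gives in + pad_lo + pad_hi - k + 1, i.e. 28 + 0 + 1 - 3 + 1 = 27 and 28 + 1 + 2 - 3 + 1 = 29. With all-ones weights and zero bias, `y_hat[1, 1]` sums a 3×2 window of real pixels (one column is padding), hence 6, while `y_hat[2, 2]` sums a full 3×3 window, hence 9.

    using Flux

    m = Conv((3, 3), 1=>1, relu; pad=(0, 1, 1, 2))
    size(m(ones(Float32, 28, 28, 1, 1)))   # -> (27, 29, 1, 1)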

View File

@@ -73,26 +73,26 @@ end
 end
 # with activation function
-let m = BatchNorm(2, sigmoid), x = param([1.0 3.0 5.0;
-                                          2.0 4.0 6.0])
+let m = BatchNorm(2, sigmoid), x = [1.0 3.0 5.0;
+                                    2.0 4.0 6.0]
   y = trainmode(m, x)
   y = m(x)
-  @test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
+  @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
 end
-let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
+let m = BatchNorm(2), x = reshape(1:6, 3, 2, 1)
   y = reshape(permutedims(x, [2, 1, 3]), 2, :)
   y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
   @test m(x) == y
 end
-let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
+let m = BatchNorm(2), x = reshape(1:12, 2, 3, 2, 1)
   y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
   y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
   @test m(x) == y
 end
-let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
+let m = BatchNorm(2), x = reshape(1:24, 2, 2, 3, 2, 1)
   y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
   y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
   @test m(x) == y
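
In these blocks `trainmode(m, x)` first updates the running statistics, and the second assignment `y = m(x)` then evaluates in test mode against them. The expectation is the test-mode normalisation, sketched here with the affine parameters at their defaults (γ = 1, β = 0), which is why the tests omit them:

    using Flux

    m = BatchNorm(2, sigmoid)
    x = [1.0 3.0 5.0; 2.0 4.0 6.0]
    y = m(x)   # == sigmoid.(m.γ .* (x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ) .+ m.β)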
@@ -156,7 +156,7 @@ end
   y = trainmode(m, x)
   y = m(x)
-  @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
+  @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
 end
 let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
@@ -193,7 +193,7 @@ end
 squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
 let m = GroupNorm(4,2), sizes = (3,4,2),
-    x = param(reshape(collect(1:prod(sizes)), sizes))
+    x = reshape(collect(1:prod(sizes)), sizes)
   x = Float64.(x)
   @test m.β == [0, 0, 0, 0] # initβ(32)
   @test m.γ == [1, 1, 1, 1] # initγ(32)
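
`GroupNorm(4, 2)` normalises four channels in two groups of two: the affine parameters `β` and `γ` are per channel (length 4, as checked above), while the statistics `μ` and `σ²` are kept per group (size `(m.G, 1)`, per the size tests further down). A sketch:

    using Flux

    m = GroupNorm(4, 2)
    length(m.β), length(m.γ)   # -> (4, 4): one shift/scale per channel
    size(m.μ), size(m.σ²)      # -> ((2, 1), (2, 1)): one statistic per group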
@@ -238,7 +238,7 @@ end
 end
 # with activation function
 let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2),
-    x = param(reshape(collect(1:prod(sizes)), sizes))
+    x = reshape(collect(1:prod(sizes)), sizes)
   x = Float64.(x)
   μ_affine_shape = ones(Int,length(sizes) + 1)
   μ_affine_shape[end-1] = 2 # Number of groups
@@ -254,12 +254,12 @@ end
   y = trainmode(m, x)
   y = m(x)
   x_ = reshape(x,affine_shape...)
-  out = reshape(data(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ))),og_shape)
+  out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) .+ m.ϵ)),og_shape)
   @test isapprox(y, out, atol = 1.0e-7)
 end
 let m = GroupNorm(2,2), sizes = (2, 4, 1, 2, 3),
-    x = param(reshape(collect(1:prod(sizes)), sizes))
+    x = reshape(collect(1:prod(sizes)), sizes)
   y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
   y = reshape(m(y), sizes...)
   @test m(x) == y
@@ -267,7 +267,7 @@ end
 # check that μ, σ², and the output are the correct size for higher rank tensors
 let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6),
-    x = param(reshape(collect(1:prod(sizes)), sizes))
+    x = reshape(collect(1:prod(sizes)), sizes)
   y = m(x)
   @test size(m.μ) == (m.G,1)
   @test size(m.σ²) == (m.G,1)
@@ -276,13 +276,13 @@ end
 # show that group norm is the same as instance norm when the group size is the same as the number of channels
 let IN = InstanceNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,5),
-    x = param(reshape(collect(1:prod(sizes)), sizes))
+    x = reshape(collect(1:prod(sizes)), sizes)
   @test IN(x) ≈ GN(x)
 end
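
With one channel per group (G equal to the number of channels), GroupNorm reduces to InstanceNorm, which is what the test above checks numerically. A sketch of the same comparison:

    using Flux

    IN, GN = InstanceNorm(4), GroupNorm(4, 4)
    x = reshape(collect(Float32, 1:240), 2, 2, 3, 4, 5)
    IN(x) ≈ GN(x)   # -> true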
 # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
 let BN = BatchNorm(4), GN = GroupNorm(4,4), sizes = (2,2,3,4,1),
-    x = param(reshape(collect(1:prod(sizes)), sizes))
+    x = reshape(collect(1:prod(sizes)), sizes)
   @test BN(x) ≈ GN(x)
 end