From c313be8e955ce1dc46c28d1c694936156a63d441 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Fri, 8 Mar 2019 12:13:58 +0000
Subject: [PATCH] rm data/param

---
 src/cuda/curnn.jl          | 12 ++++++------
 src/layers/basic.jl        |  4 ++--
 src/layers/conv.jl         |  8 ++++----
 src/layers/normalise.jl    | 20 ++++++++++----------
 src/layers/recurrent.jl    | 12 ++++++------
 src/optimise/optimisers.jl | 10 +++++-----
 src/treelike.jl            |  2 +-
 7 files changed, 34 insertions(+), 34 deletions(-)

diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl
index 7ad14102..02f78a96 100644
--- a/src/cuda/curnn.jl
+++ b/src/cuda/curnn.jl
@@ -287,13 +287,13 @@ end
 (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 
 @adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), data(x), data(h))
+  reserve, result = forwardTrain(desc(m), x, h)
   result, function (Δ)
     y, ho = result
     dy, dho = Δ
-    h_ = hBatch(x, data(h))
+    h_ = hBatch(x, h)
     dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
+    (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
     nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
   end
 end
@@ -303,10 +303,10 @@ end
   result, function (Δ)
     y, ho = result
     dy, dho, dco = Δ
-    h_ = hBatch(x, data(h))
-    c_ = hBatch(x, data(c))
+    h_ = hBatch(x, h)
+    c_ = hBatch(x, c)
     dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
+    (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
     nobacksies(:RNN,
       (dx, unbroadcast(h, dh), unbroadcast(c, dc),
        transpose(dWi), transpose(dWh), db))
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index e640bb24..dea0089f 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -89,7 +89,7 @@ Dense(W, b) = Dense(W, b, identity)
 
 function Dense(in::Integer, out::Integer, σ = identity;
                initW = glorot_uniform, initb = zeros)
-  return Dense(param(initW(out, in)), param(initb(out)), σ)
+  return Dense(initW(out, in), initb(out), σ)
 end
 
 @treelike Dense
@@ -129,7 +129,7 @@ struct Diagonal{T}
 end
 
 Diagonal(in::Integer; initα = ones, initβ = zeros) =
-  Diagonal(param(initα(in)), param(initβ(in)))
+  Diagonal(initα(in), initβ(in))
 
 @treelike Diagonal
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index a59a8c6a..d1e7ab97 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -42,7 +42,7 @@ end
 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
      init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
-  Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
+  Conv(init(k..., ch...), zeros(ch[2]), σ,
        stride = stride, pad = pad, dilation = dilation)
 
 @treelike Conv
@@ -97,7 +97,7 @@ end
 ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
               init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
-ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
+ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
   stride = stride, pad = pad, dilation = dilation)
 
 @treelike ConvTranspose
@@ -168,14 +168,14 @@ end
 DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity;
               init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
-  DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
+  DepthwiseConv(init(k..., 1, ch), zeros(ch), σ,
                 stride = stride, pad = pad, dilation=dilation)
 
 DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
               init = glorot_uniform, stride::NTuple{N,Integer} = map(_->1,k),
               pad::NTuple{N,Integer} = map(_->0,2 .* k),
               dilation::NTuple{N,Integer} = map(_->1,k)) where N =
-  DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
+  DepthwiseConv(init(k..., ch[2], ch[1]), zeros(ch[2]*ch[1]), σ,
                 stride = stride, pad = pad)
 
 @treelike DepthwiseConv
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 7c11d411..4ee6b758 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -138,7 +138,7 @@ end
 BatchNorm(chs::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
           ϵ = 1f-5, momentum = 0.1f0) =
-  BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
+  BatchNorm(λ, initβ(chs), initγ(chs),
             zeros(chs), ones(chs), ϵ, momentum, true)
 
 function (BN::BatchNorm)(x)
@@ -160,11 +160,11 @@ function (BN::BatchNorm)(x)
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
     μ = mean(x, dims = axes)
     σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
-    ϵ = data(convert(T, BN.ϵ))
+    ϵ = convert(T, BN.ϵ)
     # update moving mean/std
-    mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
-    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
+    mtm = convert(T, BN.momentum)
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(μ, :)
+    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(σ², :)
   end
 
   let λ = BN.λ
@@ -231,7 +231,7 @@ end
 InstanceNorm(chs::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
             ϵ = 1f-5, momentum = 0.1f0) =
-  InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
+  InstanceNorm(λ, initβ(chs), initγ(chs),
               zeros(chs), ones(chs), ϵ, momentum, true)
 
 function (in::InstanceNorm)(x)
@@ -256,15 +256,15 @@ function (in::InstanceNorm)(x)
   else
     T = eltype(x)
 
-    ϵ = data(convert(T, in.ϵ))
+    ϵ = convert(T, in.ϵ)
    axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
     μ = mean(x, dims = axes)
     σ² = mean((x .- μ) .^ 2, dims = axes)
 
     # update moving mean/std
-    mtm = data(convert(T, in.momentum))
-    in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
-    in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
+    mtm = convert(T, in.momentum)
+    in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(μ, (c, bs)), dims = 2), dims=2)
+    in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(σ², (c, bs))), dims = 2), dims=2)
   end
 
   let λ = in.λ
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 03e3b323..70ff3d98 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -68,8 +68,8 @@ end
 
 RNNCell(in::Integer, out::Integer, σ = tanh;
         init = glorot_uniform) =
-  RNNCell(σ, param(init(out, in)), param(init(out, out)),
-          param(init(out)), param(zeros(out)))
+  RNNCell(σ, init(out, in), init(out, out),
+          init(out), zeros(out))
 
 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -107,8 +107,8 @@ end
 
 function LSTMCell(in::Integer, out::Integer;
                   init = glorot_uniform)
-  cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
-                  param(zeros(out)), param(zeros(out)))
+  cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
+                  zeros(out), zeros(out))
-  cell.b.data[gate(out, 2)] .= 1
+  cell.b[gate(out, 2)] .= 1
   return cell
 end
@@ -153,8 +153,8 @@ mutable struct GRUCell{A,V}
 end
 
 GRUCell(in, out; init = glorot_uniform) =
-  GRUCell(param(init(out*3, in)), param(init(out*3, out)),
-          param(init(out*3)), param(zeros(out)))
+  GRUCell(init(out * 3, in), init(out * 3, out),
+          init(out * 3), zeros(out))
 
 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index aa2db1c5..da536ac6 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
 
 function apply!(o::Momentum, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(data(x))
+  v = get!(o.velocity, x, zero(x))::typeof(x)
   @. v = ρ * v - η * Δ
   @. Δ = -v
 end
@@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
 
 function apply!(o::Nesterov, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(data(x))
+  v = get!(o.velocity, x, zero(x))::typeof(x)
   d = @. ρ^2 * v - (1+ρ) * η * Δ
   @. v = ρ*v - η*Δ
   @. Δ = -d
@@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
 
 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
-  acc = get!(o.acc, x, zero(x))::typeof(data(x))
+  acc = get!(o.acc, x, zero(x))::typeof(x)
   @. acc = ρ * acc + (1 - ρ) * Δ^2
   @. Δ *= η / (√acc + ϵ)
 end
@@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
 
 function apply!(o::ADAGrad, x, Δ)
   η = o.eta
-  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
   @. acc += Δ^2
   @. Δ *= η / (√acc + ϵ)
 end
@@ -323,5 +323,5 @@ WeightDecay() = WeightDecay(0)
 
 function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
-  @. Δ += wd * data(x)
+  @. Δ += wd * x
 end
diff --git a/src/treelike.jl b/src/treelike.jl
index 07935e55..6500c644 100644
--- a/src/treelike.jl
+++ b/src/treelike.jl
@@ -51,7 +51,7 @@ function loadparams!(m, xs)
   for (p, x) in zip(params(m), xs)
     size(p) == size(x) ||
       error("Expected param size $(size(p)), got $(size(x))")
-    copyto!(data(p), data(x))
+    copyto!(p, x)
   end
 end
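
Taken together, these hunks remove the last Tracker indirection: layers now hold plain arrays instead of param(...)-wrapped TrackedArrays, and nothing unwraps values through data(...) any more, so gradients have to come from the surrounding AD (Zygote, on this branch). Below is a minimal sketch of the resulting user-facing flow. It assumes the Zygote-era Flux API (Flux.params, gradient, Descent, update!), which this commit predates in its final form, so treat the names as illustrative rather than as code from this branch:

using Flux

# After this patch, layer fields are plain Arrays (previously TrackedArrays via param).
m = Dense(10, 5)
x, y = rand(10), rand(5)
loss(x, y) = Flux.mse(m(x), y)

ps = Flux.params(m)                       # trainable parameters are the plain arrays themselves
gs = Flux.gradient(() -> loss(x, y), ps)  # gradients come from the AD, with no data(...) unwrapping
opt = Descent(0.1)
for p in ps
  Flux.Optimise.update!(opt, p, gs[p])    # apply! sees typeof(x) directly, per the optimiser hunks
end

The same shift is what lets loadparams! call copyto!(p, x) on the parameters directly in the final hunk.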