rm data/param
This commit is contained in:
parent
aa4d221f8c
commit
c313be8e95
@ -287,13 +287,13 @@ end
|
||||
# Convert `x` to a `CuArray{T}` (the element type of the state tuple `h`) and re-dispatch.
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where {T<:Union{Float32,Float64}}
    return m(h, CuArray{T}(x))
end
|
||||
|
||||
# Adjoint for calling a `CuRNN` / `CuGRU` on `(x, h, Wi, Wh, b)`.
#
# The forward pass returns `result` together with a `reserve` value that is handed
# unchanged to both backward calls. The pullback takes `Δ = (dy, dho)` and returns
# gradients for `(x, h, Wi, Wh, b)`, in that order. Inputs are used directly; they are
# no longer unwrapped with `data(...)`.
@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
    reserve, result = forwardTrain(desc(m), x, h)
    result, function (Δ)
        y, ho = result
        dy, dho = Δ
        h_ = hBatch(x, h)
        # NOTE(review): the forward pass goes through `desc(m)` while the backward pass
        # indexes `descs[m]` directly — presumably `desc(m)` populates `descs`; confirm.
        dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
        (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
        # `dh` is reduced back to the shape of `h`; weight gradients are transposed
        # before being returned.
        nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
    end
end
|
||||
@ -303,10 +303,10 @@ end
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
h_ = hBatch(x, h)
|
||||
c_ = hBatch(x, c)
|
||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
||||
transpose(dWi), transpose(dWh), db))
|
||||
|
@ -89,7 +89,7 @@ Dense(W, b) = Dense(W, b, identity)
|
||||
|
||||
# Size-based constructor: builds the weight with `initW(out, in)` and the bias with
# `initb(out)`, then forwards both plus `σ` to the three-argument `Dense` constructor.
# The arrays are stored exactly as the initialisers return them (no `param` wrapping).
function Dense(in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros)
    return Dense(initW(out, in), initb(out), σ)
end
|
||||
|
||||
@treelike Dense
|
||||
@ -129,7 +129,7 @@ struct Diagonal{T}
|
||||
end
|
||||
|
||||
# Size-based constructor: passes `initα(in)` and `initβ(in)` positionally to the
# two-argument `Diagonal` constructor, unwrapped (no `param`).
function Diagonal(in::Integer; initα = ones, initβ = zeros)
    return Diagonal(initα(in), initβ(in))
end
|
||||
|
||||
@treelike Diagonal
|
||||
|
||||
|
@ -42,7 +42,7 @@ end
|
||||
|
||||
# Convenience constructor from a kernel size `k` and a channel pair `ch`.
#
# The weight is `init(k..., ch...)`, i.e. dimensions `(k..., first(ch), last(ch))`, and
# the bias is `zeros(ch[2])`. Both are passed unwrapped (no `param`) to the array-based
# `Conv` constructor together with `σ` and the `stride` / `pad` / `dilation` keywords.
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
              init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
    weight = init(k..., ch...)
    bias = zeros(ch[2])
    return Conv(weight, bias, σ, stride = stride, pad = pad, dilation = dilation)
end
|
||||
|
||||
@treelike Conv
|
||||
@ -97,7 +97,7 @@ end
|
||||
|
||||
# Convenience constructor from a kernel size `k` and a channel pair `ch`.
#
# The weight is `init(k..., reverse(ch)...)`, so the channel dimensions are stored in the
# opposite order to the pair; the bias is `zeros(ch[2])`. Both are passed unwrapped (no
# `param`) to the array-based `ConvTranspose` constructor.
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
    weight = init(k..., reverse(ch)...)
    bias = zeros(ch[2])
    return ConvTranspose(weight, bias, σ,
                         stride = stride, pad = pad, dilation = dilation)
end
|
||||
|
||||
@treelike ConvTranspose
|
||||
@ -168,14 +168,14 @@ end
|
||||
|
||||
# Convenience constructor taking a single channel count `ch`.
#
# The weight is `init(k..., 1, ch)` and the bias is `zeros(ch)`; both are passed unwrapped
# (no `param`) to the array-based `DepthwiseConv` constructor along with the keywords.
function DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity;
                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
    weight = init(k..., 1, ch)
    bias = zeros(ch)
    return DepthwiseConv(weight, bias, σ,
                         stride = stride, pad = pad, dilation = dilation)
end
|
||||
|
||||
# Convenience constructor taking a channel pair `ch`.
#
# The weight is `init(k..., ch[2], ch[1])` and the bias has `ch[2] * ch[1]` entries; both
# are passed unwrapped (no `param`) to the array-based `DepthwiseConv` constructor.
# `stride`, `pad` and `dilation` default to per-dimension tuples derived from `k`.
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = glorot_uniform,
                       stride::NTuple{N,Integer} = map(_ -> 1, k),
                       pad::NTuple{N,Integer} = map(_ -> 0, 2 .* k),
                       dilation::NTuple{N,Integer} = map(_ -> 1, k)) where N
    weight = init(k..., ch[2], ch[1])
    bias = zeros(ch[2] * ch[1])
    # `dilation` was previously accepted but never forwarded, so a caller-supplied value
    # was silently ignored; pass it on like the `ch::Integer` method does.
    return DepthwiseConv(weight, bias, σ,
                         stride = stride, pad = pad, dilation = dilation)
end
|
||||
|
||||
@treelike DepthwiseConv
|
||||
|
@ -138,7 +138,7 @@ end
|
||||
|
||||
# Convenience constructor from a channel count `chs`.
#
# Positional arguments forwarded to the full constructor, in order: `λ`, `initβ(chs)`,
# `initγ(chs)`, `zeros(chs)`, `ones(chs)`, `ϵ`, `momentum`, `true`. The two initialised
# arrays are passed unwrapped (no `param`).
# NOTE(review): `zeros(chs)` / `ones(chs)` are Float64 while the default initialisers
# produce Float32 — confirm the mixed element types are intended.
function BatchNorm(chs::Integer, λ = identity;
                   initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
                   ϵ = 1f-5, momentum = 0.1f0)
    return BatchNorm(λ, initβ(chs), initγ(chs),
                     zeros(chs), ones(chs), ϵ, momentum, true)
end
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
@ -160,11 +160,11 @@ function (BN::BatchNorm)(x)
|
||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
ϵ = convert(T, BN.ϵ)
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, BN.momentum))
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
|
||||
mtm = convert(T, BN.momentum)
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(μ, :)
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(σ², :)
|
||||
end
|
||||
|
||||
let λ = BN.λ
|
||||
@ -231,7 +231,7 @@ end
|
||||
|
||||
# Convenience constructor from a channel count `chs`.
#
# Positional arguments forwarded to the full constructor, in order: `λ`, `initβ(chs)`,
# `initγ(chs)`, `zeros(chs)`, `ones(chs)`, `ϵ`, `momentum`, `true`. The two initialised
# arrays are passed unwrapped (no `param`).
# NOTE(review): `zeros(chs)` / `ones(chs)` are Float64 while the default initialisers
# produce Float32 — confirm the mixed element types are intended.
function InstanceNorm(chs::Integer, λ = identity;
                      initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i),
                      ϵ = 1f-5, momentum = 0.1f0)
    return InstanceNorm(λ, initβ(chs), initγ(chs),
                        zeros(chs), ones(chs), ϵ, momentum, true)
end
|
||||
|
||||
function (in::InstanceNorm)(x)
|
||||
@ -256,15 +256,15 @@ function (in::InstanceNorm)(x)
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, in.ϵ))
|
||||
ϵ = convert(T, in.ϵ)
|
||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, in.momentum))
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
||||
mtm = convert(T, in.momentum)
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(μ, (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(σ², (c, bs))), dims = 2), dims=2)
|
||||
end
|
||||
|
||||
let λ = in.λ
|
||||
|
@ -68,8 +68,8 @@ end
|
||||
|
||||
# Size-based constructor. Forwards, in order: `σ`, `init(out, in)`, `init(out, out)`,
# `init(out)` and `zeros(out)` to the full `RNNCell` constructor, all unwrapped
# (no `param`).
function RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform)
    return RNNCell(σ, init(out, in), init(out, out), init(out), zeros(out))
end
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
@ -107,8 +107,8 @@ end
|
||||
|
||||
# Size-based constructor. Forwards `init(out * 4, in)`, `init(out * 4, out)`,
# `init(out * 4)` and two `zeros(out)` vectors to the full `LSTMCell` constructor, all
# unwrapped (no `param`), then sets the `gate(out, 2)` slice of the bias to 1.
function LSTMCell(in::Integer, out::Integer; init = glorot_uniform)
    cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
                    zeros(out), zeros(out))
    # `cell.b` is now a plain array, so index it directly: the old `cell.b.data`
    # accessor only existed on the `param` wrapper and would fail here.
    cell.b[gate(out, 2)] .= 1
    return cell
end
|
||||
@ -153,8 +153,8 @@ mutable struct GRUCell{A,V}
|
||||
end
|
||||
|
||||
# Size-based constructor. Forwards `init(out * 3, in)`, `init(out * 3, out)`,
# `init(out * 3)` and `zeros(out)` to the full `GRUCell` constructor, all unwrapped
# (no `param`).
function GRUCell(in, out; init = glorot_uniform)
    return GRUCell(init(out * 3, in), init(out * 3, out), init(out * 3), zeros(out))
end
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
|
@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||
|
||||
# Momentum update. Looks up (or creates) the velocity buffer keyed by `x`, updates it in
# place as `v = ρ * v - η * Δ`, then overwrites `Δ` in place with `-v` and returns `Δ`.
function apply!(o::Momentum, x, Δ)
    η, ρ = o.eta, o.rho
    # Callable default: the zero buffer is only allocated the first time `x` is seen,
    # rather than on every call.
    v = get!(() -> zero(x), o.velocity, x)::typeof(x)
    @. v = ρ * v - η * Δ
    @. Δ = -v
    return Δ
end
|
||||
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
@. v = ρ*v - η*Δ
|
||||
@. Δ = -d
|
||||
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||
|
||||
# RMSProp update. Looks up (or creates) the accumulator keyed by `x`, updates it in place
# as `acc = ρ * acc + (1 - ρ) * Δ^2`, then scales `Δ` in place by `η / (√acc + ϵ)` and
# returns `Δ`. `ϵ` is read from the enclosing scope.
function apply!(o::RMSProp, x, Δ)
    η, ρ = o.eta, o.rho
    # Callable default: the zero buffer is only allocated the first time `x` is seen,
    # rather than on every call.
    acc = get!(() -> zero(x), o.acc, x)::typeof(x)
    @. acc = ρ * acc + (1 - ρ) * Δ^2
    @. Δ *= η / (√acc + ϵ)
    return Δ
end
|
||||
@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||
|
||||
# ADAGrad update. Looks up (or creates) the accumulator keyed by `x` (initially filled
# with `ϵ`), adds `Δ^2` to it in place, then scales `Δ` in place by `η / (√acc + ϵ)` and
# returns `Δ`. `ϵ` is read from the enclosing scope.
function apply!(o::ADAGrad, x, Δ)
    η = o.eta
    # Callable default: the initial buffer is only allocated the first time `x` is seen.
    # NOTE(review): the `::typeof(x)` assertion requires `fill(ϵ, size(x))` to have the
    # same array type as `x` — confirm for non-default element types.
    acc = get!(() -> fill(ϵ, size(x)), o.acc, x)::typeof(x)
    @. acc += Δ^2
    @. Δ *= η / (√acc + ϵ)
    return Δ
end
|
||||
@ -323,5 +323,5 @@ WeightDecay() = WeightDecay(0)
|
||||
|
||||
# Weight decay: adds `o.wd * x` to `Δ` in place, exactly once, and returns `Δ`.
# `x` is used directly; it is no longer unwrapped with `data(x)`.
function apply!(o::WeightDecay, x, Δ)
    wd = o.wd
    @. Δ += wd * x
    return Δ
end
|
||||
|
@ -51,7 +51,7 @@ function loadparams!(m, xs)
|
||||
for (p, x) in zip(params(m), xs)
|
||||
size(p) == size(x) ||
|
||||
error("Expected param size $(size(p)), got $(size(x))")
|
||||
copyto!(data(p), data(x))
|
||||
copyto!(p, x)
|
||||
end
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user