This commit is contained in:
Mike J Innes 2018-08-03 15:19:10 +01:00
parent 926411a449
commit 7103a0ed7d
3 changed files with 4 additions and 4 deletions

View File

@ -83,7 +83,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
w = cuzero(T, rnnParamSize(T, d[], input))
w = cuzeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd, x ->

View File

@ -108,9 +108,9 @@ mutable struct BatchNorm{F,V,W,N}
end
BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> fill(0.0,i), initγ = (i) -> fill(1.0,i), ϵ = 1e-8, momentum = .1) =
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
fill(0.0,chs), fill(1.0,chs), ϵ, momentum, true)
zeros(chs), ones(chs), ϵ, momentum, true)
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||

View File

@ -84,7 +84,7 @@ end
RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(fill(0.0,out)), param(initn(out)))
param(zeros(out)), param(initn(out)))
function (m::RNNCell)(h, x)
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b