Merge pull request #356 from domluna/recurrent-fix

recurrent bug fixes
This commit is contained in:
Mike J Innes 2018-08-20 11:27:49 +01:00 committed by GitHub
commit 0ef6456903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,4 @@
gate(h, n) = (1:h) + h*(n-1)
gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@ -84,7 +84,7 @@ end
RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(zeros(out)), param(initn(out)))
param(zeros(out)), param(init(out)))
function (m::RNNCell)(h, x)
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@ -122,9 +122,9 @@ end
function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)),
param(initn(out)), param(initn(out)))
cell.b.data[gate(out, 2)] = 1
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(init(out)), param(init(out)))
cell.b.data[gate(out, 2)] .= 1
return cell
end
@ -170,7 +170,7 @@ end
GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zero(out*3)), param(initn(out)))
param(zeros(out*3)), param(init(out)))
function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1)