recurrent bug fixes

This commit is contained in:
Dominique Luna 2018-08-18 11:50:52 -04:00
parent 23af487f98
commit 3f42301e07

View File

@ -1,4 +1,4 @@
gate(h, n) = (1:h) + h*(n-1)
gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@ -122,9 +122,9 @@ end
function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)),
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(initn(out)), param(initn(out)))
cell.b.data[gate(out, 2)] = 1
cell.b.data[gate(out, 2)] .= 1
return cell
end
@ -170,7 +170,7 @@ end
GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zero(out*3)), param(initn(out)))
param(zeros(out*3)), param(initn(out)))
function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1)