recurrent bug fixes

This commit is contained in:
Dominique Luna 2018-08-18 11:50:52 -04:00
parent 23af487f98
commit 3f42301e07

View File

@ -1,4 +1,4 @@
gate(h, n) = (1:h) + h*(n-1) gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)] gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@ -122,9 +122,9 @@ end
function LSTMCell(in::Integer, out::Integer; function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform) init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)), cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(initn(out)), param(initn(out))) param(initn(out)), param(initn(out)))
cell.b.data[gate(out, 2)] = 1 cell.b.data[gate(out, 2)] .= 1
return cell return cell
end end
@ -170,7 +170,7 @@ end
GRUCell(in, out; init = glorot_uniform) = GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)), GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zero(out*3)), param(initn(out))) param(zeros(out*3)), param(initn(out)))
function (m::GRUCell)(h, x) function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1) b, o = m.b, size(h, 1)