recurrent bug fixes
This commit is contained in:
parent
23af487f98
commit
3f42301e07
@ -1,4 +1,4 @@
|
||||
gate(h, n) = (1:h) + h*(n-1)
|
||||
gate(h, n) = (1:h) .+ h*(n-1)
|
||||
gate(x::AbstractVector, h, n) = x[gate(h,n)]
|
||||
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
|
||||
|
||||
@ -122,9 +122,9 @@ end
|
||||
|
||||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)),
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||
param(initn(out)), param(initn(out)))
|
||||
cell.b.data[gate(out, 2)] = 1
|
||||
cell.b.data[gate(out, 2)] .= 1
|
||||
return cell
|
||||
end
|
||||
|
||||
@ -170,7 +170,7 @@ end
|
||||
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(zero(out*3)), param(initn(out)))
|
||||
param(zeros(out*3)), param(initn(out)))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user