recurrent bug fixes
This commit is contained in:
parent
23af487f98
commit
3f42301e07
@ -1,4 +1,4 @@
|
|||||||
gate(h, n) = (1:h) + h*(n-1)
|
gate(h, n) = (1:h) .+ h*(n-1)
|
||||||
gate(x::AbstractVector, h, n) = x[gate(h,n)]
|
gate(x::AbstractVector, h, n) = x[gate(h,n)]
|
||||||
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
|
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
|
||||||
|
|
||||||
@ -122,9 +122,9 @@ end
|
|||||||
|
|
||||||
function LSTMCell(in::Integer, out::Integer;
|
function LSTMCell(in::Integer, out::Integer;
|
||||||
init = glorot_uniform)
|
init = glorot_uniform)
|
||||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)),
|
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||||
param(initn(out)), param(initn(out)))
|
param(initn(out)), param(initn(out)))
|
||||||
cell.b.data[gate(out, 2)] = 1
|
cell.b.data[gate(out, 2)] .= 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -170,7 +170,7 @@ end
|
|||||||
|
|
||||||
GRUCell(in, out; init = glorot_uniform) =
|
GRUCell(in, out; init = glorot_uniform) =
|
||||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||||
param(zero(out*3)), param(initn(out)))
|
param(zeros(out*3)), param(initn(out)))
|
||||||
|
|
||||||
function (m::GRUCell)(h, x)
|
function (m::GRUCell)(h, x)
|
||||||
b, o = m.b, size(h, 1)
|
b, o = m.b, size(h, 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user