From 3f42301e077c96b1a263da65f628794a98035c5b Mon Sep 17 00:00:00 2001 From: Dominique Luna Date: Sat, 18 Aug 2018 11:50:52 -0400 Subject: [PATCH] recurrent bug fixes --- src/layers/recurrent.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index d97c7fd7..d9c51127 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,4 +1,4 @@ -gate(h, n) = (1:h) + h*(n-1) +gate(h, n) = (1:h) .+ h*(n-1) gate(x::AbstractVector, h, n) = x[gate(h,n)] gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] @@ -122,9 +122,9 @@ end function LSTMCell(in::Integer, out::Integer; init = glorot_uniform) - cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)), + cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)), param(initn(out)), param(initn(out))) - cell.b.data[gate(out, 2)] = 1 + cell.b.data[gate(out, 2)] .= 1 return cell end @@ -170,7 +170,7 @@ end GRUCell(in, out; init = glorot_uniform) = GRUCell(param(init(out*3, in)), param(init(out*3, out)), - param(zero(out*3)), param(initn(out))) + param(zeros(out*3)), param(initn(out))) function (m::GRUCell)(h, x) b, o = m.b, size(h, 1)