# Patch summary (commentary only; the patch text below is unchanged).
#
# Files touched: src/layers/recurrent.jl and src/tracker/array.jl, one hunk each.
#
# src/layers/recurrent.jl, hunk @@ -84,7 +84,7 @@
#   In the RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform)
#   convenience constructor, the fourth argument changes from
#   `param(zero(out))` to `param(fill(0.0,out))`.
#   `out` is declared `::Integer`, so `zero(out)` is a scalar zero, while
#   `fill(0.0,out)` is a length-`out` vector of Float64 zeros.
#   NOTE(review): presumably this argument is the bias `b` read back in
#   `(m::RNNCell)(h, x)` -- confirm against the struct's field order.
#   NOTE(review): `fill(0.0,out)` hard-codes Float64; confirm this matches the
#   element type returned by `init` (`zeros(out)` would be the usual spelling).
#
# src/tracker/array.jl, hunk @@ -205,16 +205,16 @@ (under "# Reductions")
#   * The two `Base.sum(xs::TrackedArray ...)` methods (one taking a `dims`
#     keyword with no default and passing it to `track` positionally, one with
#     no keyword) are replaced by a single method with `dims = :` that forwards
#     `dims` to `track` as a keyword.
#   * The three `@grad sum` definitions (`dims::Int`, untyped `dims`, and no
#     `dims`) are replaced by a single `@grad sum(xs; dims = :)`, whose
#     pullback returns the one-element tuple `(zero(xs) .+ Δ, )`.
#   * The untyped-`dims` gradient is deleted outright. The old no-`dims`
#     `Base.sum` method and the `dims::Int` / no-`dims` gradients are instead
#     re-added as commented-out lines.
#   NOTE(review): the patch introduces commented-out code; consider dropping
#   those `+#` lines rather than committing dead code.
#   NOTE(review): this relies on `track` and `@grad` accepting and forwarding
#   keyword arguments -- neither is visible in this patch; verify.
#
# NOTE(review): the hunk headers declare 7 and 16 lines, but the whole patch
# sits on a single physical line here. Line breaks and indentation appear to
# have been collapsed, so it will likely need to be re-broken before
# `git apply` / `patch` can consume it.
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 969a777e..15a590b7 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -84,7 +84,7 @@ end RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = RNNCell(σ, param(init(out, in)), param(init(out, out)), - param(zero(out)), param(initn(out))) + param(fill(0.0,out)), param(initn(out))) function (m::RNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 7c23288f..818e5e73 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -205,16 +205,16 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) # Reductions -Base.sum(xs::TrackedArray; dims) = track(sum, xs, dims) -Base.sum(xs::TrackedArray) = track(sum, xs) +Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims) +# Base.sum(xs::TrackedArray) = track(sum, xs) Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) -@grad sum(xs, dims::Int) = sum(data(xs), dims = dims), - Δ -> (zero(xs) .+ Δ, nothing) -@grad sum(xs, dims) = sum(data(xs), dims = dims), - Δ -> (zero(xs) .+ Δ, map(_->nothing,dims)...) -@grad sum(xs) = sum(data(xs)), - Δ -> (zero(xs) .+ Δ,) +# @grad sum(xs, dims::Int) = sum(data(xs), dims = dims), + # Δ -> (zero(xs) .+ Δ, nothing) +@grad sum(xs; dims = :) = sum(data(xs), dims = dims), + Δ -> (zero(xs) .+ Δ, ) +# @grad sum(xs) = sum(data(xs)), + # Δ -> (zero(xs) .+ Δ,) Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) Base.prod(xs::TrackedArray) = track(prod, xs)