fixed the sum as suggested by mike
This commit is contained in:
parent
02f343d44d
commit
c657d4e47f
@ -84,7 +84,7 @@ end
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||
init = glorot_uniform) =
|
||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||
param(zero(out)), param(initn(out)))
|
||||
param(fill(0.0,out)), param(initn(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
|
@ -205,16 +205,16 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray; dims) = track(sum, xs, dims)
|
||||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
# Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
@grad sum(xs, dims::Int) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, nothing)
|
||||
@grad sum(xs, dims) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, map(_->nothing,dims)...)
|
||||
@grad sum(xs) = sum(data(xs)),
|
||||
Δ -> (zero(xs) .+ Δ,)
|
||||
# @grad sum(xs, dims::Int) = sum(data(xs), dims = dims),
|
||||
# Δ -> (zero(xs) .+ Δ, nothing)
|
||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, )
|
||||
# @grad sum(xs) = sum(data(xs)),
|
||||
# Δ -> (zero(xs) .+ Δ,)
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
|
Loading…
Reference in New Issue
Block a user