add gradient check for `prod` and fix `dims` in `back(::typeof(prod),...)`

This commit is contained in:
chengchingwen 2018-03-07 16:24:44 +08:00
parent 86d782a5ce
commit 7c721475c6
2 changed files with 3 additions and 1 deletions

View File

@ -143,7 +143,7 @@ Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (prod(xs.data) ./ xs.data) .* Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)

View File

@ -16,6 +16,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))