Merge pull request #197 from chengchingwen/master

Implement `prod` for `TrackedArray`
This commit is contained in:
Mike J Innes 2018-03-15 15:17:24 +00:00 committed by GitHub
commit 5d7edb5aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 0 deletions

View File

@ -139,6 +139,13 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)

View File

@ -16,6 +16,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))