fixed more dep warns, also in tests, but maximum, minimum and size in array.jl still need to be updated. As a result, some more tests may not pass for the time being

This commit is contained in:
Simon Mandlik 2018-07-18 20:20:00 +02:00 committed by Mike J Innes
parent 0471c489e6
commit 02f343d44d
2 changed files with 16 additions and 15 deletions

View File

@ -252,11 +252,11 @@ StatsBase.std(x::TrackedArray; mean = Statistics.mean(x)) =
StatsBase.std(x::TrackedArray, dim; mean = Statistics.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
@grad mean(xs, region) = mean(data(xs), dims = region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
@grad mean(xs, region) = mean(data(xs), dims=region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
@grad function maximum(xs, r...)
maximum(data(xs), r...), function (Δ)
@ -266,6 +266,7 @@ LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
end
end
@grad function minimum(xs, r...)
minimum(data(xs), r...), function (Δ)
Δ′ = zero(xs)

View File

@ -24,7 +24,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@test gradtest(x -> sum(x), randn(Float64,2,3))
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@ -133,31 +133,31 @@ end
@testset "mean" begin
@test gradtest(mean, rand(2, 3))
@test gradtest(x -> mean(x, 1), rand(2, 3))
@test gradtest(x -> mean(x, 2), rand(2, 3))
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
@test gradtest(x -> mean(x, dims=1), rand(2, 3))
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end
@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, 1), rand(2, 3))
@test gradtest(x -> maximum(x, 2), rand(2, 3))
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end
@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, 1), rand(2, 3))
@test gradtest(x -> minimum(x, 2), rand(2, 3))
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))