Updated `sum` to be compliant with the latest beta. Removed some deprecation warnings.

This commit is contained in:
pevnak 2018-07-17 16:57:39 +02:00 committed by Mike J Innes
parent e5b3d27016
commit e98538673a
4 changed files with 28 additions and 20 deletions

View File

@@ -28,7 +28,7 @@ children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
@@ -101,7 +101,7 @@ struct Diagonal{T}
β::T
end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(in::Integer; initα = ones, initβ = (x) -> similar(x) .= 0) =
Diagonal(param(initα(in)), param(initβ(in)))
@treelike Diagonal

View File

@@ -1,6 +1,7 @@
import Base: *, ==
import LinearAlgebra
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
@@ -26,7 +27,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, similar(x) .= 0)
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
@@ -204,12 +205,16 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
# Reductions
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
Base.sum(xs::TrackedArray; dims) = track(sum, xs, dims)
Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
@grad sum(xs, dim...) = sum(data(xs), dim...),
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
@grad sum(xs, dims::Int) = sum(data(xs), dims = dims),
Δ -> (zero(xs) .+ Δ, nothing)
@grad sum(xs, dims) = sum(data(xs), dims = dims),
Δ -> (zero(xs) .+ Δ, map(_->nothing,dims)...)
@grad sum(xs) = sum(data(xs)),
Δ -> (zero(xs) .+ Δ,)
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
@@ -223,8 +228,8 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
Statistics.mean(xs::TrackedArray) = track(mean, xs)
Statistics.mean(xs::TrackedArray, region) = track(mean, xs, region)
Base.maximum(xs::TrackedArray) = track(maximum, xs)
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
@@ -242,9 +247,9 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
using StatsBase
# Hacks to get std working
StatsBase.std(x::TrackedArray; mean = Base.mean(x)) =
StatsBase.std(x::TrackedArray; mean = Statistics.mean(x)) =
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
StatsBase.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
StatsBase.std(x::TrackedArray, dim; mean = Statistics.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =
@@ -349,7 +354,7 @@ dualify(xs::Real, ps) = Dual(xs, ps)
unbroadcast(x::Tuple, Δ) =
x == size(Δ) ? Δ :
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
unbroadcast(x::Tuple{}, Δ) = sum(Δ)

View File

@@ -1,8 +1,8 @@
# Arrays
initn(dims...) = randn(dims...)/100
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

View File

@@ -14,10 +14,13 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@test gradtest(x -> sum(x), randn(Float64,2,3))
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@@ -167,9 +170,9 @@ end
2y + x
end
@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
@@ -213,7 +216,7 @@ end
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
b = param(rand())
Tracker.back!(b)