Updated sum to be compliant with the latest beta. Removed some deprecation warnings.
This commit is contained in:
parent
e5b3d27016
commit
e98538673a
|
@ -28,7 +28,7 @@ children(c::Chain) = c.layers
|
|||
# Build a new chain whose layers are the results of applying `f` to each layer of `c`.
function mapchildren(f, c::Chain)
  mapped = map(f, c.layers)
  return Chain(mapped...)
end
|
||||
# Adapt every layer of `c` to `T` and wrap the adapted layers in a new `Chain`.
function adapt(T, c::Chain)
  adapted = map(layer -> adapt(T, layer), c.layers)
  return Chain(adapted...)
end
|
||||
|
||||
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||
# Calling a chain folds over its layers from the left: each layer is applied to the
# value produced so far, starting from the input `x`.
function (c::Chain)(x)
  return foldl((acc, layer) -> layer(acc), c.layers; init = x)
end
|
||||
|
||||
# Indexing a chain with an array of indices yields a new chain of the selected layers.
function Base.getindex(c::Chain, i::AbstractArray)
  selected = c.layers[i]
  return Chain(selected...)
end
|
||||
|
||||
|
@ -101,7 +101,7 @@ struct Diagonal{T}
|
|||
β::T
|
||||
end
|
||||
|
||||
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
||||
# Build a `Diagonal` layer of size `in`. Both parameters are created by calling the
# initialiser with the integer `in` and wrapping the result in `param`.
# The initialisers therefore have to accept an `Integer`: `zeros(in)` does, whereas
# `similar(in)` has no method for integers, so `zeros` is the default for `initβ`.
Diagonal(in::Integer; initα = ones, initβ = zeros) =
  Diagonal(param(initα(in)), param(initβ(in)))
|
||||
|
||||
@treelike Diagonal
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import Base: *, ==
|
||||
|
||||
import LinearAlgebra
|
||||
using Statistics
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm
|
||||
|
||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
|
@ -26,7 +27,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|||
# Construct a `TrackedArray` from a call `c`, an array `x` and an array `Δ` of the
# same type `A`. The element type and rank are taken from `A`, and the first field is
# a `Tracked{A}` built from `c` and `Δ`.
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
||||
# Track `x` with an empty `Call()` and a zero-filled array shaped like `x`.
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, fill!(similar(x), 0))
|
||||
|
||||
# The element type of a `TrackedArray` type with real elements `T` is `TrackedReal{T}`.
Base.eltype(::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
|
@ -204,12 +205,16 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
# One method covers both the whole-array reduction and the reduction over `dims`.
# Keyword arguments do not take part in dispatch, so a separate
# `sum(xs::TrackedArray)` method would replace a keyword-only
# `sum(xs::TrackedArray; dims)` method and leave `sum(xs, dims = ...)` unsupported.
# `dims = :` (the default) means "reduce over everything".
Base.sum(xs::TrackedArray; dims = :) =
  dims === (:) ? track(sum, xs) : track(sum, xs, dims)
|
||||
# `sum(f, xs)`: broadcast `f` over `xs`, then reduce the result with `sum`.
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
@grad sum(xs, dim...) = sum(data(xs), dim...),
|
||||
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
|
||||
# `@grad` rule for `sum(xs, dims)` with an integer `dims`. The first element is
# `sum` of `data(xs)` over `dims`; the second is a closure mapping `Δ` to a tuple of
# `Δ` broadcast against `zero(xs)` and `nothing` for the `dims` argument.
@grad sum(xs, dims::Int) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, nothing)
|
||||
# `@grad` rule for `sum(xs, dims)` when `dims` is not a single `Int` (a tuple or a
# vector of dimensions, for example). The closure returns one entry per argument of
# `sum`: `Δ` broadcast against `zero(xs)`, and a single `nothing` for `dims`,
# regardless of how many dimensions `dims` names.
@grad sum(xs, dims) = sum(data(xs), dims = dims),
  Δ -> (zero(xs) .+ Δ, nothing)
|
||||
# `@grad` rule for the whole-array `sum(xs)`. The first element is `sum` of
# `data(xs)`; the second is a closure mapping `Δ` to a one-element tuple holding `Δ`
# broadcast against `zero(xs)`.
@grad sum(xs) = sum(data(xs)),
|
||||
Δ -> (zero(xs) .+ Δ,)
|
||||
|
||||
# `prod` over a `TrackedArray`, with or without a `dim` argument; both methods pass
# the call on to `track`.
# NOTE(review): `dim` is positional here while `sum` in this file takes `dims` as a
# keyword; confirm whether `prod` should follow the same convention.
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
|
@ -223,8 +228,8 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
|||
|
||||
# `findfirst` on a `TrackedArray` searches its `data` field directly, forwarding any
# extra arguments unchanged.
function Base.findfirst(xs::TrackedArray, args...)
  return findfirst(xs.data, args...)
end
|
||||
|
||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
# `mean` over a `TrackedArray`, with or without a `region` argument; both methods
# pass the call on to `track`.
# NOTE(review): `region` is positional here while `sum` in this file takes `dims` as
# a keyword; confirm whether `mean` should follow the same convention.
Statistics.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Statistics.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
|
||||
# `maximum` over a `TrackedArray`, with or without a `region` argument; both methods
# pass the call on to `track`.
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
||||
|
@ -242,9 +247,9 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
|||
using StatsBase
|
||||
|
||||
# Hacks to get std working
|
||||
StatsBase.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
# Standard deviation over all elements of `x`: the square root of the summed squared
# deviations from `mean`, divided by `length(x) - 1`. `mean` defaults to
# `Statistics.mean(x)` and may be supplied by the caller.
function StatsBase.std(x::TrackedArray; mean = Statistics.mean(x))
  sqdev = sum((x .- mean).^2)
  return sqrt.(sqdev ./ (length(x)-1))
end
|
||||
StatsBase.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
# Standard deviation of `x` along `dim`: the square root of the squared deviations
# from `mean` summed over `dim`, divided by `size(x, dim) - 1`. `mean` defaults to
# `Statistics.mean(x, dim)` and may be supplied by the caller.
# The reduction passes `dim` through the `dims` keyword, matching the keyword form
# of `sum` used for tracked arrays in this file.
StatsBase.std(x::TrackedArray, dim; mean = Statistics.mean(x, dim)) =
  sqrt.(sum((x .- mean).^2, dims = dim) ./ (size(x, dim)-1))
|
||||
|
||||
LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =
|
||||
|
@ -349,7 +354,7 @@ dualify(xs::Real, ps) = Dual(xs, ps)
|
|||
|
||||
# Reduce `Δ` to the size given by the tuple `x`. If `Δ` already has that size it is
# returned unchanged; otherwise it is summed over every dimension `n` that either
# lies beyond `length(x)` or has `x[n] == 1`, and the result is reshaped to `x`.
unbroadcast(x::Tuple, Δ) =
|
||||
x == size(Δ) ? Δ :
|
||||
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
|
||||
# For an empty size tuple, collapse `Δ` to the sum of all its elements.
function unbroadcast(x::Tuple{}, Δ)
  return sum(Δ)
end
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Arrays
|
||||
|
||||
# Standard-normal samples of shape `dims`, divided by 100.
initn(dims...) = randn(dims...) ./ 100
|
||||
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
||||
# Uniform samples of shape `dims`, shifted to be centred on zero and scaled by
# `sqrt(24 / sum(dims))`.
function glorot_uniform(dims...)
  scale = sqrt(24.0 / sum(dims))
  centred = rand(dims...) .- 0.5
  return centred .* scale
end
|
||||
# Standard-normal samples of shape `dims`, scaled by `sqrt(2 / sum(dims))`.
function glorot_normal(dims...)
  scale = sqrt(2.0 / sum(dims))
  return randn(dims...) .* scale
end
|
||||
|
||||
# Insert a singleton dimension at position `dim`: the sizes before `dim` are kept,
# a `1` is added, and the remaining sizes follow.
function unsqueeze(xs, dim)
  sz = size(xs)
  return reshape(xs, (sz[1:dim-1]..., 1, sz[dim:end]...))
end
|
||||
|
||||
|
|
|
@ -14,10 +14,13 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
|
||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
||||
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
||||
|
||||
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
||||
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x), (3,4,5))
|
||||
|
||||
|
@ -167,9 +170,9 @@ end
|
|||
2y + x
|
||||
end
|
||||
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
@ -213,7 +216,7 @@ end
|
|||
|
||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
|
||||
|
||||
b = param(rand())
|
||||
Tracker.back!(b)
|
||||
|
|
Loading…
Reference in New Issue