removed most error, the only one in Fallbacks test persits
This commit is contained in:
parent
c657d4e47f
commit
926411a449
|
@ -4,7 +4,7 @@ module Flux
|
|||
|
||||
# Zero Flux Given
|
||||
|
||||
using MacroTools, Juno, Requires, Reexport, StatsBase, Random
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
||||
|
|
|
@ -108,9 +108,9 @@ mutable struct BatchNorm{F,V,W,N}
|
|||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
|
||||
initβ = (i) -> fill(0.0,i), initγ = (i) -> fill(1.0,i), ϵ = 1e-8, momentum = .1) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zero(chs), ones(chs), ϵ, momentum, true)
|
||||
fill(0.0,chs), fill(1.0,chs), ϵ, momentum, true)
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
|
|
|
@ -206,15 +206,10 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
# Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
# @grad sum(xs, dims::Int) = sum(data(xs), dims = dims),
|
||||
# Δ -> (zero(xs) .+ Δ, nothing)
|
||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, )
|
||||
# @grad sum(xs) = sum(data(xs)),
|
||||
# Δ -> (zero(xs) .+ Δ,)
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
|
@ -228,13 +223,10 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
|||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Statistics.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Statistics.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
||||
|
||||
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
||||
Base.minimum(xs::TrackedArray) = track(minimum, xs)
|
||||
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
|
||||
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
||||
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
||||
|
||||
import LinearAlgebra: dot
|
||||
|
||||
|
@ -244,35 +236,33 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
|||
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
using StatsBase
|
||||
|
||||
# Hacks to get std working
|
||||
StatsBase.std(x::TrackedArray; mean = Statistics.mean(x)) =
|
||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||
StatsBase.std(x::TrackedArray, dim; mean = Statistics.mean(x, dim)) =
|
||||
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
||||
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
||||
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
||||
|
||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
|
||||
@grad mean(xs, region) = mean(data(xs), dims=region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
|
||||
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
||||
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
||||
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
||||
|
||||
@grad function maximum(xs, r...)
|
||||
maximum(data(xs), r...), function (Δ)
|
||||
@grad function maximum(xs; dims = dims)
|
||||
maximum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmax(data(xs), r...)
|
||||
_, i = findmax(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:maximum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
@grad function minimum(xs, r...)
|
||||
minimum(data(xs), r...), function (Δ)
|
||||
@grad function minimum(xs; dims = dims)
|
||||
minimum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmin(data(xs), r...)
|
||||
_, i = findmin(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:minimum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
|||
using NNlib: conv
|
||||
using Printf: @sprintf
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||
using Statistics: mean
|
||||
using Statistics: mean, std
|
||||
# using StatsBase
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
|
@ -161,7 +161,7 @@ end
|
|||
end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||
@test gradtest(x -> std(x, dims = 1), rand(5,5))
|
||||
|
||||
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
|
||||
@test gradtest(dot, rand(5), rand(5))
|
||||
|
|
Loading…
Reference in New Issue