removed most errors, the only one in the Fallbacks test persists

pevnak 2018-07-19 10:58:43 +02:00 committed by Mike J Innes
parent c657d4e47f
commit 926411a449
4 changed files with 22 additions and 32 deletions

View File

@@ -4,7 +4,7 @@ module Flux
# Zero Flux Given
using MacroTools, Juno, Requires, Reexport, StatsBase, Random
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv,
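
This first hunk only swaps the StatsBase dependency for the Statistics standard library, which on Julia 0.7+ supplies mean, std and the other reductions used below. A minimal sketch of what that assumption buys (illustration only, not part of the commit):

    using Statistics            # stdlib on Julia 0.7+, no external package needed
    x = rand(3, 4)
    mean(x; dims = 1)           # 1×4 matrix of column means
    std(x; dims = 2)            # 3×1 matrix of row standard deviations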

View File

@@ -108,9 +108,9 @@ mutable struct BatchNorm{F,V,W,N}
end
BatchNorm(chs::Integer, λ = identity;
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
initβ = (i) -> fill(0.0,i), initγ = (i) -> fill(1.0,i), ϵ = 1e-8, momentum = .1) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zero(chs), ones(chs), ϵ, momentum, true)
fill(0.0,chs), fill(1.0,chs), ϵ, momentum, true)
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
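
The BatchNorm change replaces the zeros/ones initialiser defaults with explicit fill-based closures and builds the running statistics buffers with fill as well; the keyword interface itself is unchanged. A hedged usage sketch of that constructor (the custom initialiser below is hypothetical, only the keyword name comes from the hunk):

    using Flux
    bn  = BatchNorm(4)                                   # β = fill(0.0, 4), γ = fill(1.0, 4)
    bn2 = BatchNorm(4, relu; initβ = i -> fill(0.1, i))  # hypothetical custom initialiser
    x   = rand(4, 10)                                    # 4 channels, batch of 10
    bn(x)                                                # normalises over the batch dimension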

View File

@@ -206,15 +206,10 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
# Reductions
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
# Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
# @grad sum(xs, dims::Int) = sum(data(xs), dims = dims),
# Δ -> (zero(xs) .+ Δ, nothing)
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
Δ -> (zero(xs) .+ Δ, )
# @grad sum(xs) = sum(data(xs)),
# Δ -> (zero(xs) .+ Δ,)
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
@@ -228,13 +223,10 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Statistics.mean(xs::TrackedArray) = track(mean, xs)
Statistics.mean(xs::TrackedArray, region) = track(mean, xs, region)
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
Base.maximum(xs::TrackedArray) = track(maximum, xs)
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
Base.minimum(xs::TrackedArray) = track(minimum, xs)
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
import LinearAlgebra: dot
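
With the hunks above, the tracked reductions follow the Base/Statistics convention of a dims keyword instead of a positional region argument. A minimal call-pattern sketch, assuming the usual param workflow (illustration only):

    using Flux, Statistics
    x = param(rand(3, 4))       # TrackedArray
    sum(x; dims = 1)            # tracked 1×4 result
    mean(x; dims = 2)           # tracked 3×1 result
    maximum(x; dims = 1)        # tracked; the @grad below routes the gradient to the argmax entries
    minimum(x)                  # whole-array reduction still returns a tracked scalar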
@@ -244,35 +236,33 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
using StatsBase
# Hacks to get std working
StatsBase.std(x::TrackedArray; mean = Statistics.mean(x)) =
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
StatsBase.std(x::TrackedArray, dim; mean = Statistics.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
@grad mean(xs, region) = mean(data(xs), dims=region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
@grad function maximum(xs, r...)
maximum(data(xs), r...), function (Δ)
@grad function maximum(xs; dims = dims)
maximum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
_, i = findmax(data(xs), r...)
_, i = findmax(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
return (nobacksies(:maximum, Δ′),)
end
end
@grad function minimum(xs, r...)
minimum(data(xs), r...), function (Δ)
@grad function minimum(xs; dims = dims)
minimum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
_, i = findmin(data(xs), r...)
_, i = findmin(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
return (nobacksies(:minimum, Δ′),)
end
end
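
The reworked std keeps the usual Bessel correction: the denominator is the product of the reduced dimension sizes minus one, so the tracked forward pass should agree with Statistics.std on plain data while staying differentiable. A hedged check of that property (illustration only, not part of the commit):

    using Flux, Statistics
    using Flux.Tracker: data, back!
    x = param(rand(4, 6))
    s = std(x; dims = 2)                        # goes through the _std helper above
    isapprox(data(s), std(data(x); dims = 2))   # matches the untracked result
    back!(sum(s))                               # gradients accumulate in x.grad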

View File

@@ -4,7 +4,7 @@ using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean
using Statistics: mean, std
# using StatsBase
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@@ -161,7 +161,7 @@ end
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))
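
The test update mirrors the API change: std is now exercised through the dims keyword. An equivalent hand-rolled check in the style of the gradtest helper above (assuming Flux.Tracker.gradcheck, which the test file already imports):

    using Flux, Statistics
    using Flux.Tracker: gradcheck
    # compare Tracker gradients against numerical gradients, as gradtest does
    gradcheck(x -> sum(sin.(std(x; dims = 1))), rand(5, 5))   # expected to return true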