parent
2b80573248
commit
af96a197c1
|
@ -7,9 +7,6 @@ nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) #In case of conv
|
||||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
||||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
||||||
|
|
||||||
he_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / last(nfan(dims...)))
|
|
||||||
he_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / last(nfan(dims...)))
|
|
||||||
|
|
||||||
ones(T::Type, dims...) = Base.ones(T, dims...)
|
ones(T::Type, dims...) = Base.ones(T, dims...)
|
||||||
zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
using Flux
|
using Flux
|
||||||
using Flux: throttle, nfan, glorot_uniform, glorot_normal, he_uniform, he_normal,
|
using Flux: throttle, nfan, glorot_uniform, glorot_normal, stack, unstack
|
||||||
stack, unstack
|
|
||||||
using StatsBase: var
|
using StatsBase: var
|
||||||
using Random
|
using Random
|
||||||
using Test
|
using Test
|
||||||
|
@ -78,18 +77,6 @@ end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "he" begin
|
|
||||||
# he_uniform and he_normal should both yield a kernel with variance ≈ 2/fan_in
|
|
||||||
for dims ∈ [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
|
|
||||||
for init ∈ [he_uniform, he_normal]
|
|
||||||
v = init(dims...)
|
|
||||||
fan_in, fan_out = nfan(dims...)
|
|
||||||
σ2 = 2 / fan_out
|
|
||||||
@test 0.9σ2 < var(v) < 1.1σ2
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Params" begin
|
@testset "Params" begin
|
||||||
|
|
Loading…
Reference in New Issue