Merge pull request #441 from Paethon/rm_initn
Removes initn initialization
This commit is contained in:
commit
30486f9c03
|
@ -30,8 +30,8 @@ Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
|
||||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
|
@ -110,9 +110,9 @@ Max pooling layer. `k` stands for the size of the window for each dimension of t
|
|||
Takes the keyword arguments `pad` and `stride`.
|
||||
"""
|
||||
struct MaxPool{N}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
k::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
stride::NTuple{N,Int}
|
||||
end
|
||||
|
||||
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
# Arrays
|
||||
|
||||
initn(dims...) = randn(dims...)/100
|
||||
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux
|
||||
using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
|
||||
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
|
||||
using StatsBase: std
|
||||
using Random
|
||||
using Test
|
||||
|
@ -64,10 +64,6 @@ end
|
|||
@testset "Initialization" begin
|
||||
# Set random seed so that these tests don't fail randomly
|
||||
Random.seed!(0)
|
||||
# initn() should yield a kernel with stddev ~= 1e-2
|
||||
v = initn(10, 10)
|
||||
@test std(v) > 0.9*1e-2
|
||||
@test std(v) < 1.1*1e-2
|
||||
|
||||
# glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
|
||||
# and glorot_normal should yield a kernel with stddev != 2/(n_in _ n_out)
|
||||
|
|
Loading…
Reference in New Issue