easier initialisation with weights

This commit is contained in:
Mike J Innes 2018-02-15 20:52:29 +00:00
parent 01c31e7fcc
commit 63862c2324
2 changed files with 9 additions and 3 deletions

View File

@ -59,14 +59,16 @@ Tracked 2-element Array{Float64,1}:
```
"""
struct Dense{F,S,T}
σ::F
W::S
b::T
σ::F
end
Dense(W, b) = Dense(W, b, identity)
function Dense(in::Integer, out::Integer, σ = identity;
initW = glorot_uniform, initb = zeros)
return Dense(σ, param(initW(out, in)), param(initb(out)))
return Dense(param(initW(out, in)), param(initb(out)), σ)
end
treelike(Dense)

View File

@ -18,9 +18,13 @@ struct Conv2D{F,A,V}
pad::Int
end
Conv2D(w::AbstractArray{T,4}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where T =
Conv2D(σ, w, b, stride, pad)
Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = initn, stride = 1, pad = 0) =
Conv2D(σ, param(init(k..., ch...)), param(zeros(ch[2])), stride, pad)
Conv2D(param(init(k..., ch...)), param(zeros(ch[2])), σ, stride = stride, pad = pad)
Flux.treelike(Conv2D)