From 63862c2324bb24e08c37b51643ab4941dec6dfc7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 15 Feb 2018 20:52:29 +0000 Subject: [PATCH] easier initialisation with weights --- src/layers/basic.jl | 6 ++++-- src/layers/conv.jl | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index b3b09270..f93e6818 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -59,14 +59,16 @@ Tracked 2-element Array{Float64,1}: ``` """ struct Dense{F,S,T} - σ::F W::S b::T + σ::F end +Dense(W, b) = Dense(W, b, identity) + function Dense(in::Integer, out::Integer, σ = identity; initW = glorot_uniform, initb = zeros) - return Dense(σ, param(initW(out, in)), param(initb(out))) + return Dense(param(initW(out, in)), param(initb(out)), σ) end treelike(Dense) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 74b3b86a..c94b642b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -18,9 +18,13 @@ struct Conv2D{F,A,V} pad::Int end +Conv2D(w::AbstractArray{T,4}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0) where T = + Conv2D(σ, w, b, stride, pad) + Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride = 1, pad = 0) = - Conv2D(σ, param(init(k..., ch...)), param(zeros(ch[2])), stride, pad) + Conv2D(param(init(k..., ch...)), param(zeros(ch[2])), σ, stride = stride, pad = pad) Flux.treelike(Conv2D)