layer wip

This commit is contained in:
Mike J Innes 2017-12-15 13:22:57 +00:00
parent 0bf22dfb8e
commit 9d0dd9fb7e
2 changed files with 18 additions and 2 deletions

View File

@ -7,13 +7,14 @@ module Flux
using Juno, Requires
using Lazy: @forward
export Chain, Dense, RNN, LSTM,
export Chain, Dense, RNN, LSTM, Conv2D,
Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad,
param, params, mapleaves
using NNlib
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
conv2d, maxpool2d, avgpool2d
include("tracker/Tracker.jl")
using .Tracker
@ -27,6 +28,7 @@ include("treelike.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")

14
src/layers/conv.jl Normal file
View File

@ -0,0 +1,14 @@
struct Conv2D{F,A}
σ::F
weight::A
stride::Int
end
Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = initn, stride = 1) =
Conv2D(σ, param(initn(k..., ch...)), stride)
Flux.treelike(Conv2D)
# (c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride))
(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight))