diff --git a/src/Flux.jl b/src/Flux.jl index 526d6bb8..2acdb177 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,13 +7,14 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, +export Chain, Dense, RNN, LSTM, Conv2D, Dropout, LayerNorm, BatchNorm, SGD, ADAM, Momentum, Nesterov, AMSGrad, param, params, mapleaves using NNlib -export σ, sigmoid, relu, leakyrelu, elu, swish, softmax +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, + conv2d, maxpool2d, avgpool2d include("tracker/Tracker.jl") using .Tracker @@ -27,6 +28,7 @@ include("treelike.jl") include("layers/stateless.jl") include("layers/basic.jl") +include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") diff --git a/src/layers/conv.jl b/src/layers/conv.jl new file mode 100644 index 00000000..f7ca6f02 --- /dev/null +++ b/src/layers/conv.jl @@ -0,0 +1,14 @@ +struct Conv2D{F,A} + σ::F + weight::A + stride::Int +end + +Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; + init = initn, stride = 1) = + Conv2D(σ, param(initn(k..., ch...)), stride) + +Flux.treelike(Conv2D) + +# (c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride)) +(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight))