layer wip

This commit is contained in:
Mike J Innes 2017-12-15 13:22:57 +00:00
parent 0bf22dfb8e
commit 9d0dd9fb7e
2 changed files with 18 additions and 2 deletions

View File

@ -7,13 +7,14 @@ module Flux
using Juno, Requires using Juno, Requires
using Lazy: @forward using Lazy: @forward
export Chain, Dense, RNN, LSTM, export Chain, Dense, RNN, LSTM, Conv2D,
Dropout, LayerNorm, BatchNorm, Dropout, LayerNorm, BatchNorm,
SGD, ADAM, Momentum, Nesterov, AMSGrad, SGD, ADAM, Momentum, Nesterov, AMSGrad,
param, params, mapleaves param, params, mapleaves
using NNlib using NNlib
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
conv2d, maxpool2d, avgpool2d
include("tracker/Tracker.jl") include("tracker/Tracker.jl")
using .Tracker using .Tracker
@ -27,6 +28,7 @@ include("treelike.jl")
include("layers/stateless.jl") include("layers/stateless.jl")
include("layers/basic.jl") include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl") include("layers/recurrent.jl")
include("layers/normalisation.jl") include("layers/normalisation.jl")

14
src/layers/conv.jl Normal file
View File

@ -0,0 +1,14 @@
struct Conv2D{F,A}
σ::F
weight::A
stride::Int
end
Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = initn, stride = 1) =
Conv2D(σ, param(initn(k..., ch...)), stride)
Flux.treelike(Conv2D)
# (c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride))
(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight))