layer wip
This commit is contained in:
parent
0bf22dfb8e
commit
9d0dd9fb7e
@ -7,13 +7,14 @@ module Flux
|
|||||||
using Juno, Requires
|
using Juno, Requires
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM,
|
export Chain, Dense, RNN, LSTM, Conv2D,
|
||||||
Dropout, LayerNorm, BatchNorm,
|
Dropout, LayerNorm, BatchNorm,
|
||||||
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||||
param, params, mapleaves
|
param, params, mapleaves
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
|
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
|
||||||
|
conv2d, maxpool2d, avgpool2d
|
||||||
|
|
||||||
include("tracker/Tracker.jl")
|
include("tracker/Tracker.jl")
|
||||||
using .Tracker
|
using .Tracker
|
||||||
@ -27,6 +28,7 @@ include("treelike.jl")
|
|||||||
|
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("layers/basic.jl")
|
include("layers/basic.jl")
|
||||||
|
include("layers/conv.jl")
|
||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
|
|
||||||
|
14
src/layers/conv.jl
Normal file
14
src/layers/conv.jl
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
struct Conv2D{F,A}
|
||||||
|
σ::F
|
||||||
|
weight::A
|
||||||
|
stride::Int
|
||||||
|
end
|
||||||
|
|
||||||
|
Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
|
init = initn, stride = 1) =
|
||||||
|
Conv2D(σ, param(initn(k..., ch...)), stride)
|
||||||
|
|
||||||
|
Flux.treelike(Conv2D)
|
||||||
|
|
||||||
|
# (c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride))
|
||||||
|
(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight))
|
Loading…
Reference in New Issue
Block a user