layer wip
This commit is contained in:
parent
0bf22dfb8e
commit
9d0dd9fb7e
@ -7,13 +7,14 @@ module Flux
|
||||
using Juno, Requires
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
export Chain, Dense, RNN, LSTM, Conv2D,
|
||||
Dropout, LayerNorm, BatchNorm,
|
||||
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||
param, params, mapleaves
|
||||
|
||||
using NNlib
|
||||
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
|
||||
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
|
||||
conv2d, maxpool2d, avgpool2d
|
||||
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
@ -27,6 +28,7 @@ include("treelike.jl")
|
||||
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
include("layers/conv.jl")
|
||||
include("layers/recurrent.jl")
|
||||
include("layers/normalisation.jl")
|
||||
|
||||
|
14
src/layers/conv.jl
Normal file
14
src/layers/conv.jl
Normal file
@ -0,0 +1,14 @@
|
||||
struct Conv2D{F,A}
|
||||
σ::F
|
||||
weight::A
|
||||
stride::Int
|
||||
end
|
||||
|
||||
Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = initn, stride = 1) =
|
||||
Conv2D(σ, param(initn(k..., ch...)), stride)
|
||||
|
||||
Flux.treelike(Conv2D)
|
||||
|
||||
# (c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride))
|
||||
(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight))
|
Loading…
Reference in New Issue
Block a user