Resolving Merge Conflicts
Commit c7c0ee2cbc
CITATION.bib (new file, 29 lines)
@@ -0,0 +1,29 @@
+@article{Flux.jl-2018,
+  author    = {Michael Innes and
+               Elliot Saba and
+               Keno Fischer and
+               Dhairya Gandhi and
+               Marco Concetto Rudilosso and
+               Neethu Mariya Joy and
+               Tejan Karmali and
+               Avik Pal and
+               Viral Shah},
+  title     = {Fashionable Modelling with Flux},
+  journal   = {CoRR},
+  volume    = {abs/1811.01457},
+  year      = {2018},
+  url       = {http://arxiv.org/abs/1811.01457},
+  archivePrefix = {arXiv},
+  eprint    = {1811.01457},
+  timestamp = {Thu, 22 Nov 2018 17:58:30 +0100},
+  biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1811-01457},
+  bibsource = {dblp computer science bibliography, https://dblp.org}
+}
+
+@article{innes:2018,
+  author  = {Mike Innes},
+  title   = {Flux: Elegant Machine Learning with Julia},
+  journal = {Journal of Open Source Software},
+  year    = {2018},
+  doi     = {10.21105/joss.00602},
+}
NEWS.md (6 changes)
@@ -1,5 +1,10 @@
+# v0.9.0
+* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
+* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
+
 # v0.8.0
 
+* [Dropout now has a `dims` argument for specifying the unbroadcast dimensions.](https://github.com/FluxML/Flux.jl/pull/563)
 * New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
 * New [Maxout layer](https://github.com/FluxML/Flux.jl/pull/647)
 * Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
@@ -13,6 +18,7 @@
 * [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
 * New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 * New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
+* New [CrossCor](https://github.com/FluxML/Flux.jl/pull/762).
 
 AD Changes:
 
@@ -17,6 +17,7 @@ MaxPool
 MeanPool
 DepthwiseConv
 ConvTranspose
+CrossCor
 ```
 
 ## Recurrent Layers
@@ -36,17 +37,7 @@ But in contrast to the layers described in the other sections are not readily gr
 
 ```@docs
 Maxout
-```
-
-# Normalisation & Regularisation
-
-These layers don't affect the structure of the network but may improve training times or reduce overfitting.
-
-```@docs
-Flux.testmode!
-BatchNorm
-Dropout
-LayerNorm
+SkipConnection
 ```
 
 ## Activation Functions
@@ -18,7 +18,7 @@ And you use less memory.
 Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
 they should also preserve the type of their inputs.
 
-A very artificial example using an activatioon function like
+A very artificial example using an activation function like
 
 ```
 my_tanh(x) = Float64(tanh(x))
@@ -73,4 +73,4 @@ end
 ```
 
 When doing this kind of concatenation use `reduce(hcat, xs)` rather than `hcat(xs...)`.
-This will avoid the splatting penality, and will hit the optimised `reduce` method.
+This will avoid the splatting penalty, and will hit the optimised `reduce` method.
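A quick illustration of the `reduce(hcat, xs)` advice above (a sketch, not part of the diff; the array sizes and names are made up):

```julia
xs = [rand(Float32, 10) for _ in 1:1000]   # e.g. one column per time step

bad  = hcat(xs...)        # splats 1000 arguments into a single call
good = reduce(hcat, xs)   # hits the specialised reduce(hcat, ...) method

bad == good               # same 10×1000 matrix, without the splatting penalty
```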
@@ -10,7 +10,7 @@ using Zygote: Params, @adjoint, gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       params, mapleaves, cpu, gpu, f32, f64, param, data
+       SkipConnection,params, mapleaves, cpu, gpu, f32, f64, param, data
 
 include("optimise/Optimise.jl")
 using .Optimise
@@ -190,5 +190,37 @@ function (mo::Maxout)(input::AbstractArray)
   mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
 end
 
+"""
+    SkipConnection(layers...)
+
+Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
+and a shortcut connection linking the input to the block to the
+output through a user-supplied callable.
+
+`SkipConnection` requires the output dimension to be the same as the input.
+
+A 'ResNet'-type skip-connection with identity shortcut would simply be
+```julia
+SkipConnection(layer, (a,b) -> a + b)
+```
+"""
+struct SkipConnection
+  layers
+  connection  #user can pass arbitrary connections here, such as (a,b) -> a + b
+end
+
+@treelike SkipConnection
+
+function (skip::SkipConnection)(input)
+  #We apply the layers to the input and return the result of the application of the layers and the original input
+  skip.connection(skip.layers(input), input)
+end
+
+function Base.show(io::IO, b::SkipConnection)
+  print(io, "SkipConnection(")
+  join(io, b.layers, ", ")
+  print(io, ")")
+end
 param(x) = x
 data(x) = x
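For reference, a brief usage sketch of the new layer. This is hypothetical usage written against the docstring and the tests in this diff, not code taken from the commit:

```julia
using Flux

# Identity ("ResNet"-style) shortcut: the block output has the same size as the input.
block = SkipConnection(Chain(Dense(10, 10, relu), Dense(10, 10)), (a, b) -> a + b)
x = randn(Float32, 10, 16)      # 10 features, batch of 16
size(block(x))                  # (10, 16)

# The connection is an arbitrary callable, e.g. concatenation along dims = 2.
cat_block = SkipConnection(Dense(10, 10), (a, b) -> cat(a, b, dims = 2))
size(cat_block(randn(10, 2)))   # (10, 4), matching the "concat size" test below
```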
@@ -14,11 +14,11 @@ Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
 
     size = (2,2)
     in = 1
     out = 16
     Conv((2, 2), 1=>16, relu)
 
 Data should be stored in WHCN order (width, height, # channels, # batches).
 In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
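A hedged sketch of the keyword arguments the `Conv` docstring mentions (the sizes below are illustrative and not taken from the diff):

```julia
using Flux

# 3×3 window, 3 input channels, 16 output channels, with explicit
# padding, stride and dilation keywords.
layer = Conv((3, 3), 3 => 16, relu; pad = 1, stride = 2, dilation = 1)

x = rand(Float32, 100, 100, 3, 1)   # WHCN: width × height × channels × batch
size(layer(x))                      # (50, 50, 16, 1) with pad = 1 and stride = 2
```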
@@ -136,18 +136,17 @@ end
 (a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 """
-    DepthwiseConv(size, in)
-    DepthwiseConv(size, in=>mul)
-    DepthwiseConv(size, in=>mul, relu)
+    DepthwiseConv(size, in=>out)
+    DepthwiseConv(size, in=>out, relu)
 
 Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
-`in` and `mul` specify the number of input channels and channel multiplier respectively.
-In case the `mul` is not specified it is taken as 1.
+`in` and `out` specify the number of input and output channels respectively.
+Note that `out` must be an integer multiple of `in`.
 
 Data should be stored in WHCN order. In other words, a 100×100 RGB image would
 be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
 
-Takes the keyword arguments `pad` and `stride`.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
 struct DepthwiseConv{N,M,F,A,V}
   σ::F
@@ -187,8 +186,8 @@ function (c::DepthwiseConv)(x)
 end
 
 function Base.show(io::IO, l::DepthwiseConv)
-  print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2])
-  print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
+  print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
+  print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end]))
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
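A short sketch of the new `in => out` depthwise specification described above; it mirrors the updated tests further down and is not copied from the diff:

```julia
using Flux

x = zeros(Float32, 28, 28, 3, 5)        # 3 input channels, batch of 5

m = DepthwiseConv((2, 2), 3 => 15)      # out = 15 is an integer multiple of in = 3
size(m(x), 3)                           # 15 output channels

DepthwiseConv((2, 2), 3 => 10)          # errors: 10 is not a multiple of 3
```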
@@ -198,6 +197,76 @@ end
 
 (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
+"""
+    CrossCor(size, in=>out)
+    CrossCor(size, in=>out, relu)
+
+Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
+giving us a 16-channel output. Output is activated with ReLU.
+
+    size = (2,2)
+    in = 1
+    out = 16
+    CrossCor((2, 2), 1=>16, relu)
+
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+"""
+struct CrossCor{N,M,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{M,Int}
+  dilation::NTuple{N,Int}
+end
+
+function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+                  stride = 1, pad = 0, dilation = 1) where {T,N}
+  stride = expand(Val(N-2), stride)
+  pad = expand(Val(2*(N-2)), pad)
+  dilation = expand(Val(N-2), dilation)
+  return CrossCor(σ, w, b, stride, pad, dilation)
+end
+
+CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+         init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
+  CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
+           stride = stride, pad = pad, dilation = dilation)
+
+@treelike CrossCor
+
+function crosscor(x, w, ddims::DenseConvDims)
+  ddims = DenseConvDims(ddims, F=true)
+  return conv(x, w, ddims)
+end
+
+function (c::CrossCor)(x::AbstractArray)
+  # TODO: breaks gpu broadcast :(
+  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
+  σ.(crosscor(x, c.weight, cdims) .+ b)
+end
+
+function Base.show(io::IO, l::CrossCor)
+  print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
 
 """
     MaxPool(k)
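A brief sketch of how `CrossCor` relates to `Conv`, based on the `crosscor` helper above and on the CrossCor test added later in this commit; it is not itself part of the diff:

```julia
using Flux

x = rand(Float32, 28, 28, 1, 1)
w = rand(2, 2, 1, 1)

# CrossCor applies the kernel as-is (the F=true flag above), while Conv flips
# it, so flipping the weights by hand should make the two layers agree.
CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) ≈ Conv(w, [0.0])(x)   # true
```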
@@ -3,11 +3,12 @@ istraining() = false
 @adjoint istraining() = true, _ -> nothing
 
 """
-    Dropout(p)
+    Dropout(p, dims = :)
 
 A Dropout layer. For each input, either sets that input to `0` (with probability
-`p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it
-reduces overfitting during training.
+`p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
+dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
+used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref).
 
 Does nothing to the input once in [`testmode!`](@ref).
 """
@@ -19,13 +20,16 @@ mutable struct Dropout{F}
   end
 end
 
+_dropout_shape(s, ::Colon) = size(s)
+_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
+
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
 
 function (a::Dropout)(x)
   istraining() || return x
   y = similar(x)
   rand!(y)
-  y .= _dropout_kernel.(y, a.p, 1 - a.p)
+  y .= _dropout_kernel.(y, p, 1 - p)
   return x .* y
 end
 
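A hedged sketch of the new `dims` behaviour, written against `_dropout_shape` above and the Dropout test added later in this diff; it is not code from the commit:

```julia
using Flux

x = rand(100, 50)
m = Dropout(0.5, dims = 2)

# Dropout only fires while gradients are being taken: `istraining()` is `true`
# only inside Zygote, via the `@adjoint` definition above; otherwise m(x) == x.
y = m(x)

# With dims = 2 the mask has broadcast shape (1, 50), so it is shared across rows.
# When dropout is active, whole columns are zeroed or rescaled together, and every
# row ends up with the same number of zeros, which is what the new test asserts:
c = map(i -> count(iszero, @view y[i, :]), 1:100)
minimum(c) == maximum(c)
```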
@@ -36,6 +36,10 @@ c = gpu(Conv((2,2),3=>4))
 l = c(gpu(rand(10,10,3,2)))
 Flux.back!(sum(l))
 
+c = gpu(CrossCor((2,2),3=>4))
+l = c(gpu(rand(10,10,3,2)))
+Flux.back!(sum(l))
+
 end
 
 @testset "onecold gpu" begin
@@ -1,16 +1,17 @@
 using Flux, CuArrays, Test
+trainmode(f, x...) = forward(f, x...)[1]
 #
 # @testset "CUDNN BatchNorm" begin
 # @testset "4D Input" begin
-# x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
+# x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
 # m = BatchNorm(3)
 # cx = gpu(x)
 # cm = gpu(m)
 #
-# y = m(x)
-# cy = cm(cx)
+# y = trainmode(m, x)
+# cy = trainmode(cm, cx)
 #
-# @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
+# # @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
 #
 # @test cpu(data(cy)) ≈ data(y)
 #
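The `trainmode` helper introduced above ties the tests to the `istraining()` mechanism: layers like `Dropout` only act while Zygote is recording a pullback. A sketch, assuming `forward` is importable as `Zygote.forward` on this branch:

```julia
using Flux
using Zygote: forward

trainmode(f, x...) = forward(f, x...)[1]   # run f as if gradients were being taken

m = Dropout(0.5)
x = rand(100)

count(iszero, m(x))             # 0: istraining() is false, the input passes through
count(iszero, trainmode(m, x))  # roughly half: dropout is active inside forward
```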
|
@ -72,4 +72,16 @@ import Flux: activations
|
|||||||
@test length(ps) == 8 #4 alts, each with weight and bias
|
@test length(ps) == 8 #4 alts, each with weight and bias
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "SkipConnection" begin
|
||||||
|
@testset "zero sum" begin
|
||||||
|
input = randn(10, 10, 10, 10)
|
||||||
|
@test SkipConnection(x -> zeros(size(x)), (a,b) -> a + b)(input) == input
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "concat size" begin
|
||||||
|
input = randn(10, 2)
|
||||||
|
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
@@ -39,20 +39,14 @@ end
 
 @testset "Depthwise Conv" begin
   r = zeros(Float32, 28, 28, 3, 5)
-  m1 = DepthwiseConv((2, 2), 3=>5)
+  m1 = DepthwiseConv((2, 2), 3=>15)
   @test size(m1(r), 3) == 15
-  m2 = DepthwiseConv((2, 2), 3)
-  @test size(m2(r), 3) == 3
 
-  x = zeros(Float64, 28, 28, 3, 5)
-
-  m3 = DepthwiseConv((2, 2), 3 => 5)
-  @test size(m3(r), 3) == 15
-
-  m4 = DepthwiseConv((2, 2), 3)
-
-  @test size(m4(r), 3) == 3
+  m3 = DepthwiseConv((2, 3), 3=>9)
+  @test size(m3(r), 3) == 9
+
+  # Test that we cannot ask for non-integer multiplication factors
+  @test_throws AssertionError DepthwiseConv((2,2), 3=>10)
 end
 
 @testset "ConvTranspose" begin
@@ -61,3 +55,50 @@ end
   x_hat = ConvTranspose((3, 3), 1 => 1)(y)
   @test size(x_hat) == size(x)
 end
+
+@testset "CrossCor" begin
+  x = rand(Float32, 28, 28, 1, 1)
+  w = rand(2,2,1,1)
+  y = CrossCor(w, [0.0])
+
+  @test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1]
+
+  r = zeros(Float32, 28, 28, 1, 5)
+  m = Chain(
+    CrossCor((2, 2), 1=>16, relu),
+    MaxPool((2,2)),
+    CrossCor((2, 2), 16=>8, relu),
+    MaxPool((2,2)),
+    x -> reshape(x, :, size(x, 4)),
+    Dense(288, 10), softmax)
+
+  @test size(m(r)) == (10, 5)
+  @test y(x) != Conv(w, [0.0])(x)
+  @test CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) == Conv(w, [0.0])(x)
+end
+
+@testset "Conv with non quadratic window #700" begin
+  data = zeros(Float32, 7,7,1,1)
+  data[4,4,1,1] = 1
+
+  l = Conv((3,3), 1=>1)
+  expected = zeros(eltype(l.weight),5,5,1,1)
+  expected[2:end-1,2:end-1,1,1] = l.weight
+  @test expected == l(data)
+
+  l = Conv((3,1), 1=>1)
+  expected = zeros(eltype(l.weight),5,7,1,1)
+  expected[2:end-1,4,1,1] = l.weight
+  @test expected == l(data)
+
+  l = Conv((1,3), 1=>1)
+  expected = zeros(eltype(l.weight),7,5,1,1)
+  expected[4,2:end-1,1,1] = l.weight
+  @test expected == l(data)
+
+  @test begin
+    # we test that the next expression does not throw
+    randn(Float32, 10,10,1,1) |> Conv((6,1), 1=>1, Flux.σ)
+    true
+  end
+end
@@ -25,6 +25,16 @@ trainmode(f, x...) = forward(f, x...)[1]
   @test count(a->a == 0, y) > 50
   y = m(x)
   @test count(a->a == 0, y) == 0
+
+  x = rand(100, 50)
+  m = Dropout(0.5, dims = 2)
+  y = m(x)
+  c = map(i->count(a->a==0, @view y[i, :]), 1:100)
+  @test minimum(c) == maximum(c)
+  m = Dropout(0.5, dims = 1)
+  y = m(x)
+  c = map(i->count(a->a==0, @view y[:, i]), 1:50)
+  @test minimum(c) == maximum(c)
 end
 
 @testset "BatchNorm" begin