2018-08-23 13:34:11 +00:00
|
|
|
|
using Test, Random
|
2019-03-28 09:07:04 +00:00
|
|
|
|
import Flux: activations
|
|
|
|
|
|
2018-08-23 13:34:11 +00:00
|
|
|
|
@testset "basic" begin
    # Local, fixed-seed RNG: every random input below is reproducible, and the
    # global RNG (which other test files may depend on) is left untouched.
    rng = MersenneTwister(0)

    @testset "helpers" begin
        @testset "activations" begin
            # `activations` should return each layer's output of the chain, in order.
            dummy_model = Chain(x -> x.^2, x -> x .- 3, x -> tan.(x))
            x = randn(rng, 10)
            acts = activations(dummy_model, x)
            @test length(acts) == 3
            # Same floating-point operations as the chain, so exact equality is expected.
            @test acts[1] == x.^2
            @test acts[2] == (x.^2 .- 3)
            @test acts[3] == tan.(x.^2 .- 3)

            # An empty chain has no layers and therefore no intermediate outputs.
            @test activations(Chain(), x) == ()

            # The second layer returns a Symbol, so the results must allow
            # heterogeneous (`Any`) element types.
            @test activations(Chain(identity, x -> :foo), x)[2] == :foo
        end
    end

    @testset "Chain" begin
        @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(rng, 10))
        # The second layer expects 2 inputs but receives 5 from the first.
        @test_throws DimensionMismatch Chain(Dense(10, 5, σ), Dense(2, 1))(randn(rng, 10))
        # numeric test should be put into testset of corresponding layer
    end

    @testset "Activations" begin
        # Differentiating through an intermediate output returned by
        # `activations` should neither warn nor error.
        c = Chain(Dense(3, 5, relu), Dense(5, 1, relu))
        X = ones(Float32, 3)
        @test_nowarn gradient(() -> activations(c, X)[2][1], params(c))
    end

    @testset "Dense" begin
        @test length(Dense(10, 5)(randn(rng, 10))) == 5
        @test_throws DimensionMismatch Dense(10, 5)(randn(rng, 1))
        @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
        @test_throws MethodError Dense(10, 5).(randn(rng, 10)) # avoid broadcasting

        # With all-ones weights and zero bias, every output row equals the
        # column sums of the input.
        @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10, 1)) == 10 * ones(1, 1)
        @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10, 2)) == 10 * ones(1, 2)
        @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10, 1)) == 10 * ones(2, 1)
        @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10, 1) 2 * ones(10, 1)]) ==
              [10 20; 10 20]
    end

    @testset "Diagonal" begin
        @test length(Flux.Diagonal(10)(randn(rng, 10))) == 10
        # Scalar and length-1 inputs are broadcast up to the layer size.
        @test length(Flux.Diagonal(10)(1)) == 10
        @test length(Flux.Diagonal(10)(randn(rng, 1))) == 10
        @test_throws DimensionMismatch Flux.Diagonal(10)(randn(rng, 2))

        @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
        @test Flux.Diagonal(2)([1, 2]) == [1, 2]
        @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
    end

    @testset "Maxout" begin
        # Note that the normal common usage of Maxout is as per the docstring
        # These are abnormal constructors used for testing purposes

        @testset "Constructor" begin
            mo = Maxout(() -> identity, 4)
            input = rand(rng, 40)
            @test mo(input) == input
        end

        @testset "simple alternatives" begin
            mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
            # `rand` draws non-negative values, so `2x` is always the largest alternative.
            input = rand(rng, 40)
            @test mo(input) == 2 * input
        end

        @testset "complex alternatives" begin
            mo = Maxout((x -> [0.5; 0.1] * x, x -> [0.2; 0.7] * x))
            input = [3.0 2.0]
            # Row 1 is won by the first alternative (0.5 > 0.2), row 2 by the
            # second (0.7 > 0.1); compare approximately since these are floats.
            target = [0.5, 0.7] .* input
            @test mo(input) ≈ target
        end

        @testset "params" begin
            mo = Maxout(() -> Dense(32, 64), 4)
            ps = params(mo)
            @test length(ps) == 8 # 4 alts, each with weight and bias
        end
    end

    @testset "SkipConnection" begin
        @testset "zero sum" begin
            input = randn(rng, 10, 10, 10, 10)
            # Adding an all-zero branch must leave the input unchanged.
            @test SkipConnection(x -> zeros(size(x)), (a, b) -> a + b)(input) == input
        end

        @testset "concat size" begin
            input = randn(rng, 10, 2)
            # Concatenating layer output and input along dim 2 doubles the column count.
            @test size(SkipConnection(Dense(10, 10), (a, b) -> cat(a, b, dims = 2))(input)) == (10, 4)
        end
    end

    @testset "output dimensions" begin
        # Each 3×3 convolution here trims 2 from every spatial dimension (10 → 8 → 6).
        m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
        @test Flux.outdims(m, (10, 10)) == (6, 6)

        m = Dense(10, 5)
        @test Flux.outdims(m, (5, 2)) == (5,)
        @test Flux.outdims(m, (10,)) == (5,)

        m = Flux.Diagonal(10)
        @test Flux.outdims(m, (10,)) == (10,)

        m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
        @test Flux.outdims(m, (10, 10)) == (8, 8)
    end
end
|