diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 005915bb..76704d0c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -40,7 +40,24 @@ function Base.show(io::IO, c::Chain) print(io, ")") end -activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x) + +# This is a temporary and naive implementation +# it might be replaced in the future for better performance +# see issue https://github.com/FluxML/Flux.jl/issues/702 +# Johnny Chen -- @johnnychen94 +""" + activations(c::Chain, input) +Calculate the forward results of each layer in Chain `c` with `input` as model input. +""" +function activations(c::Chain, input) + rst = [] + for l in c + x = get(rst, length(rst), input) + push!(rst, l(x)) + end + return rst +end + """ Dense(in::Integer, out::Integer, σ = identity) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 3c5229f4..4deb545f 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -1,63 +1,75 @@ using Test, Random +import Flux: activations @testset "basic" begin - @testset "Chain" begin - @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10)) - @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10)) - # numeric test should be put into testset of corresponding layer + @testset "helpers" begin + @testset "activations" begin + dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax) + x = rand(10) + @test activations(Chain(), x) == [] + @test activations(dummy_model, x)[1] == dummy_model[1](x) + @test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2] + @test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type + end + end + + @testset "Chain" begin + @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10)) + @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10)) + # numeric test should be put into testset of corresponding layer + end + + @testset "Dense" begin + @test length(Dense(10, 5)(randn(10))) == 5 + 
@test_throws DimensionMismatch Dense(10, 5)(randn(1)) + @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting + @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting + + @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1) + @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2) + @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1) + @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] + + end + + @testset "Diagonal" begin + @test length(Flux.Diagonal(10)(randn(10))) == 10 + @test length(Flux.Diagonal(10)(1)) == 10 + @test length(Flux.Diagonal(10)(randn(1))) == 10 + @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2)) + + @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2] + @test Flux.Diagonal(2)([1,2]) == [1,2] + @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4] + end + + @testset "Maxout" begin + # Note that the normal common usage of Maxout is as per the docstring + # These are abnormal constructors used for testing purposes + + @testset "Constructor" begin + mo = Maxout(() -> identity, 4) + input = rand(40) + @test mo(input) == input end - @testset "Dense" begin - @test length(Dense(10, 5)(randn(10))) == 5 - @test_throws DimensionMismatch Dense(10, 5)(randn(1)) - @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting - @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting - - @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1) - @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2) - @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1) - @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] - + @testset "simple alternatives" begin + mo = Maxout((x -> x, x -> 2x, x -> 0.5x)) + input = rand(40) + @test mo(input) == 
2*input end - @testset "Diagonal" begin - @test length(Flux.Diagonal(10)(randn(10))) == 10 - @test length(Flux.Diagonal(10)(1)) == 10 - @test length(Flux.Diagonal(10)(randn(1))) == 10 - @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2)) - - @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2] - @test Flux.Diagonal(2)([1,2]) == [1,2] - @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4] + @testset "complex alternatives" begin + mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)) + input = [3.0 2.0] + target = [0.5, 0.7].*input + @test mo(input) == target end - @testset "Maxout" begin - # Note that the normal common usage of Maxout is as per the docstring - # These are abnormal constructors used for testing purposes - - @testset "Constructor" begin - mo = Maxout(() -> identity, 4) - input = rand(40) - @test mo(input) == input - end - - @testset "simple alternatives" begin - mo = Maxout((x -> x, x -> 2x, x -> 0.5x)) - input = rand(40) - @test mo(input) == 2*input - end - - @testset "complex alternatives" begin - mo = Maxout((x -> [0.5; 0.1]*x, x -> [0.2; 0.7]*x)) - input = [3.0 2.0] - target = [0.5, 0.7].*input - @test mo(input) == target - end - - @testset "params" begin - mo = Maxout(()->Dense(32, 64), 4) - ps = params(mo) - @test length(ps) == 8 #4 alts, each with weight and bias - end + @testset "params" begin + mo = Maxout(()->Dense(32, 64), 4) + ps = params(mo) + @test length(ps) == 8 #4 alts, each with weight and bias end + end end