simpler test

This commit is contained in:
dsweber2 2019-11-14 14:05:53 -08:00
parent 0fe3ac4e77
commit 58c794702d
2 changed files with 9 additions and 7 deletions

View File

@ -31,7 +31,7 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
(c::Chain)(x) = applychain(c.layers, x)
(c::Chain)(x, i) = extraChain(c.layers, x)[i]
(c::Chain)(x) = extraChain(c.layers, x)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
@ -60,7 +60,7 @@ function extraChain(fs::Tuple, x)
return (res, extraChain(Base.tail(fs), res)...)
end
extraChain(::Tuple{}, x) = []
extraChain(::Tuple{}, x) = ()

View File

@ -4,11 +4,13 @@ import Flux: activations
@testset "basic" begin
@testset "helpers" begin
@testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10)
@test activations(Chain(), x) == []
@test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x))
x = randn(10)
@test activations(dummy_model, x)[1] == x.^2
@test activations(dummy_model, x)[2] == (x.^2 .- 3)
@test activations(dummy_model, x)[3] == tan.(x.^2 .- 3)
@test activations(Chain(), x) == ()
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
end
end