add x into results

This commit is contained in:
JohnnyChen 2019-03-28 19:28:59 +08:00
parent c4ebd199db
commit 13c58494ec
2 changed files with 8 additions and 16 deletions

View File

@ -42,18 +42,14 @@ end
""" """
activations(c::Union{Chain, Any}, x) activations(c::Chain, x)
Calculate the forward results of each layers in Chain `c` Calculate the forward results of each layers in Chain `c`
""" """
activations(m, x) = activations(Chain(m), x)
function activations(c::Chain, x) function activations(c::Chain, x)
if length(c) == 0 if isempty(c)
return [] return [x, ]
end
rst = [c[1](x), ]
if length(c) == 1
return rst
end end
rst = [x, c[1](x)]
for l in c[2:end] for l in c[2:end]
push!(rst, l(rst[end])) push!(rst, l(rst[end]))
end end

View File

@ -6,16 +6,12 @@ import Flux: activations
@testset "activations" begin @testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax) dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10) x = rand(10)
@test_nowarn activations(dummy_model[1], x) @test activations(Chain(), x) == [x, ]
@test_nowarn activations(dummy_model[1:end-1], x) @test activations(dummy_model, x)[2] == dummy_model[1](x)
@test_nowarn activations(dummy_model, x) @test activations(dummy_model, x)[3] == x |> dummy_model[1] |> dummy_model[2]
@test activations(Chain(), x) == []
@test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
end end
end end
@testset "Chain" begin @testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10)) @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10)) @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))