correct the function behavior; support Any type

This commit is contained in:
JohnnyChen 2019-04-05 18:16:44 +08:00
parent 82595648e2
commit de7a5f4024
2 changed files with 10 additions and 5 deletions

View File

@ -41,15 +41,18 @@ function Base.show(io::IO, c::Chain)
end end
# This is a temporary and naive inplementation
# see issue https://github.com/FluxML/Flux.jl/issues/702
""" """
activations(c::Chain, x) activations(c::Chain, x)
Calculate the forward results of each layers in Chain `c` Calculate the forward results of each layers in Chain `c`
""" """
function activations(c::Chain, x) function activations(c::Chain, x)
rst = Array{Any,1}()
if isempty(c) if isempty(c)
return [x, ] return rst
end end
rst = [x, c[1](x)] push!(rst, c[1](x))
for l in c[2:end] for l in c[2:end]
push!(rst, l(rst[end])) push!(rst, l(rst[end]))
end end

View File

@ -6,9 +6,11 @@ import Flux: activations
@testset "activations" begin @testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax) dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10) x = rand(10)
@test activations(Chain(), x) == [x, ] @test activations(Chain(), x) == []
@test activations(dummy_model, x)[2] == dummy_model[1](x) @test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[3] == x |> dummy_model[1] |> dummy_model[2] @test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
@test_nowarn activations(Chain(identity, x->:foo), 1) # different types
end end
end end