add x into results
This commit is contained in:
parent
c4ebd199db
commit
13c58494ec
@ -42,18 +42,14 @@ end
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
activations(c::Union{Chain, Any}, x)
|
activations(c::Chain, x)
|
||||||
Calculate the forward results of each layers in Chain `c`
|
Calculate the forward results of each layers in Chain `c`
|
||||||
"""
|
"""
|
||||||
activations(m, x) = activations(Chain(m), x)
|
|
||||||
function activations(c::Chain, x)
|
function activations(c::Chain, x)
|
||||||
if length(c) == 0
|
if isempty(c)
|
||||||
return []
|
return [x, ]
|
||||||
end
|
|
||||||
rst = [c[1](x), ]
|
|
||||||
if length(c) == 1
|
|
||||||
return rst
|
|
||||||
end
|
end
|
||||||
|
rst = [x, c[1](x)]
|
||||||
for l in c[2:end]
|
for l in c[2:end]
|
||||||
push!(rst, l(rst[end]))
|
push!(rst, l(rst[end]))
|
||||||
end
|
end
|
||||||
|
@ -6,16 +6,12 @@ import Flux: activations
|
|||||||
@testset "activations" begin
|
@testset "activations" begin
|
||||||
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
||||||
x = rand(10)
|
x = rand(10)
|
||||||
@test_nowarn activations(dummy_model[1], x)
|
@test activations(Chain(), x) == [x, ]
|
||||||
@test_nowarn activations(dummy_model[1:end-1], x)
|
@test activations(dummy_model, x)[2] == dummy_model[1](x)
|
||||||
@test_nowarn activations(dummy_model, x)
|
@test activations(dummy_model, x)[3] == x |> dummy_model[1] |> dummy_model[2]
|
||||||
|
|
||||||
@test activations(Chain(), x) == []
|
|
||||||
@test activations(dummy_model, x)[1] == dummy_model[1](x)
|
|
||||||
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Chain" begin
|
@testset "Chain" begin
|
||||||
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
||||||
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
||||||
|
Loading…
Reference in New Issue
Block a user