naive implementation of activations
This commit is contained in:
parent
b5a6207350
commit
ccfe0f8720
|
@ -40,7 +40,23 @@ function Base.show(io::IO, c::Chain)
|
|||
print(io, ")")
|
||||
end
|
||||
|
||||
activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
|
||||
|
||||
"""
|
||||
activations(c::Union{Chain, Any}, x)
|
||||
Calculate the forward results of each layers in Chain `c`
|
||||
"""
|
||||
activations(m, x) = activations(Chain(m), x)
|
||||
function activations(c::Chain, x)
|
||||
rst = [c[1](x), ]
|
||||
if length(c) == 1
|
||||
return rst
|
||||
end
|
||||
for l in c[2:end]
|
||||
push!(rst, l(rst[end]))
|
||||
end
|
||||
return rst
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
Dense(in::Integer, out::Integer, σ = identity)
|
||||
|
|
|
@ -1,4 +1,18 @@
|
|||
using Test, Random
|
||||
import Flux: activations
|
||||
|
||||
@testset "helpers" begin
|
||||
@testset "activations" begin
|
||||
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
|
||||
x = rand(10)
|
||||
@test_nowarn activations(dummy_model[1], x)
|
||||
@test_nowarn activations(dummy_model[1:end-1], x)
|
||||
@test_nowarn activations(dummy_model, x)
|
||||
|
||||
@test activations(dummy_model, x)[1] == dummy_model[1](x)
|
||||
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
|
||||
end
|
||||
end
|
||||
|
||||
@testset "basic" begin
|
||||
@testset "Chain" begin
|
||||
|
|
Loading…
Reference in New Issue