Merge pull request #710 from johnnychen94/master

naive implementation of activations
This commit is contained in:
Mike J Innes 2019-04-05 15:33:31 +01:00 committed by GitHub
commit 54d9229be9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 53 deletions

View File

@ -40,7 +40,24 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end
activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
# This is a temporary and naive implementation
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
# Johnny Chen -- @johnnychen94
"""
activations(c::Chain, input)
Calculate the forward results of each layers in Chain `c` with `input` as model input.
"""
function activations(c::Chain, input)
rst = []
for l in c
x = get(rst, length(rst), input)
push!(rst, l(x))
end
return rst
end
"""
Dense(in::Integer, out::Integer, σ = identity)

View File

@ -1,6 +1,18 @@
using Test, Random
import Flux: activations
@testset "basic" begin
@testset "helpers" begin
@testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10)
@test activations(Chain(), x) == []
@test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
end
end
@testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))