From 3d41dca33871ee1a25e443bfe47d2e5f291091b9 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 16 Nov 2018 12:22:15 +0000 Subject: [PATCH] immutable chain --- src/Flux.jl | 1 + src/layers/basic.jl | 13 ++++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 48847fbe..da040aa0 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -2,6 +2,7 @@ module Flux # Zero Flux Given +using Base: tail using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward diff --git a/src/layers/basic.jl b/src/layers/basic.jl index c0188bf2..fddd4fc9 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -16,18 +16,21 @@ m(x) == m[2](m[1](x)) `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`. `m[1:3](x)` will calculate the output of the first three layers. """ -struct Chain - layers::Vector{Any} - Chain(xs...) = new([xs...]) +struct Chain{T<:Tuple} + layers::T + Chain(xs...) = new{typeof(xs)}(xs) end -@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push! +@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex @forward Chain.layers Base.iterate children(c::Chain) = c.layers mapchildren(f, c::Chain) = Chain(f.(c.layers)...) -(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x) +applychain(::Tuple{}, x) = x +applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) + +(c::Chain)(x) = applychain(c.layers, x) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)