diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f42a9619..e8dde1a3 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -44,17 +44,15 @@ end # it might be replaced in the future for better performance # see issue https://github.com/FluxML/Flux.jl/issues/702 # Johnny Chen -- @johnnychen94 +# only slightly changed to better handle interaction with Zygote @dsweber2 """ activations(c::Chain, input) Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - rst = [] - for l in c - x = get(rst, length(rst), input) - push!(rst, l(x)) - end - return rst + buffed = accumulate!((x,y)->y(x), Zygote.Buffer([], length(c)), + [l for l in c], dims=1, init=input) + return copy(buffed) end