make activations zygote friendly

This commit is contained in:
dsweber2 2019-09-10 00:54:49 -07:00
parent 9d05afaccc
commit cdaaca8cfa

View File

@ -44,17 +44,15 @@ end
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
# Johnny Chen -- @johnnychen94
# only slightly changed to better handle interaction with Zygote @dsweber2
"""
activations(c::Chain, input)
Calculate the forward results of each layers in Chain `c` with `input` as model input.
"""
function activations(c::Chain, input)
rst = []
for l in c
x = get(rst, length(rst), input)
push!(rst, l(x))
end
return rst
buffed = accumulate!((x,y)->y(x), Zygote.Buffer([], length(c)),
[l for l in c], dims=1, init=input)
return copy(buffed)
end