recursive way of doing activations

This commit is contained in:
dsweber2 2019-09-11 17:36:37 -07:00
parent f41219133e
commit 46abfbbd5c

View File

@ -51,16 +51,17 @@ end
Calculate the forward results of each layers in Chain `c` with `input` as model input. Calculate the forward results of each layers in Chain `c` with `input` as model input.
""" """
function activations(c::Chain, input) function activations(c::Chain, input)
res = Zygote.Buffer([], length(c)) extraChain(c.layers, input)
if length(c) > 0
res[1] = c[1](input)
for (i,l) in enumerate(c[2:end])
res[i+1] = l(res[i])
end
end
return copy(res)
end end
function extraChain(fs::Tuple, x)
res = first(fs)(x)
return (res, extraChain(Base.tail(fs), res)...)
end
extraChain(::Tuple{}, x) = []
""" """
Dense(in::Integer, out::Integer, σ = identity) Dense(in::Integer, out::Integer, σ = identity)