recursive way of doing activations
This commit is contained in:
parent
f41219133e
commit
46abfbbd5c
@ -51,16 +51,17 @@ end
|
|||||||
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
Calculate the forward results of each layers in Chain `c` with `input` as model input.
|
||||||
"""
|
"""
|
||||||
function activations(c::Chain, input)
|
function activations(c::Chain, input)
|
||||||
res = Zygote.Buffer([], length(c))
|
extraChain(c.layers, input)
|
||||||
if length(c) > 0
|
|
||||||
res[1] = c[1](input)
|
|
||||||
for (i,l) in enumerate(c[2:end])
|
|
||||||
res[i+1] = l(res[i])
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return copy(res)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function extraChain(fs::Tuple, x)
|
||||||
|
res = first(fs)(x)
|
||||||
|
return (res, extraChain(Base.tail(fs), res)...)
|
||||||
|
end
|
||||||
|
|
||||||
|
extraChain(::Tuple{}, x) = []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Dense(in::Integer, out::Integer, σ = identity)
|
Dense(in::Integer, out::Integer, σ = identity)
|
||||||
|
Loading…
Reference in New Issue
Block a user