diff --git a/src/layers/activation.jl b/src/layers/activation.jl index 0ed8bb38..e7d9cc8d 100644 --- a/src/layers/activation.jl +++ b/src/layers/activation.jl @@ -16,6 +16,7 @@ flatten(xs) = reshape(xs, length(xs)) infer(::typeof(softmax), x) = x infer(::typeof(tanh), x) = x +infer(::typeof(relu), x) = x infer(::typeof(σ), x) = x infer(::typeof(flatten), x::Dims) = (x[1], prod(x[2:end])...)