diff --git a/src/activation.jl b/src/activation.jl
index 398fe0fe..37eddb8f 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -1,5 +1,11 @@
-export σ
+export σ, relu, softmax
 
 σ(x) = 1 ./ (1 .+ exp.(-x))
 
 back!(::typeof(σ), Δ, x) = Δ .* σ(x)./(1.-σ(x))
+
+relu(x) = max(0, x)
+
+back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)
+
+softmax(x) = error("not implemented")
diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index 832b4c86..57b3fc09 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -30,3 +30,5 @@ end
 node(::typeof(*), args...) = mx.dot(args...)
 node(::typeof(+), args...) = mx.broadcast_plus(args...)
 node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
+node(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
+node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
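
As a quick sanity check outside the diff, the relu gradient rule and the softmax expression used by the MXNet node can be sketched in plain Julia. The helper names relu_grad and naive_softmax are illustrative only and are not part of the library's back!/node machinery.

# Scalar relu as in the diff; applied elementwise via broadcasting.
relu(x) = max(0, x)

# Gradient rule from the diff: pass Δ through wherever the input was positive.
relu_grad(Δ, x) = Δ .* (x .> 0)

# Finite-difference check, keeping inputs away from the kink at zero.
x = [-2.0, -0.5, 0.3, 1.5]
Δ = ones(length(x))
ϵ = 1e-6
fd = (relu.(x .+ ϵ) .- relu.(x .- ϵ)) ./ (2ϵ)
@assert maximum(abs.(fd .- relu_grad(Δ, x))) < 1e-6

# The MXNet softmax node divides exp(xs) by the (reshaped) sum of exp(xs);
# the plain-Julia equivalent normalises the input into a probability vector.
naive_softmax(xs) = exp.(xs) ./ sum(exp.(xs))
@assert isapprox(sum(naive_softmax(x)), 1.0)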