more activation functions
This commit is contained in:
parent
526165c897
commit
29aab1e4e0
@ -1,5 +1,11 @@
|
|||||||
export σ
|
export σ, relu, softmax
|
||||||
|
|
||||||
σ(x) = 1 ./ (1 .+ exp.(-x))
|
σ(x) = 1 ./ (1 .+ exp.(-x))
|
||||||
|
|
||||||
back!(::typeof(σ), Δ, x) = Δ .* σ(x)./(1.-σ(x))
|
back!(::typeof(σ), Δ, x) = Δ .* σ(x)./(1.-σ(x))
|
||||||
|
|
||||||
|
relu(x) = max(0, x)
|
||||||
|
|
||||||
|
back(::typeof(relu), Δ, x) = Δ .* (x .< 0)
|
||||||
|
|
||||||
|
softmax(x) = error("not implemented")
|
||||||
|
@ -30,3 +30,5 @@ end
|
|||||||
node(::typeof(*), args...) = mx.dot(args...)
|
node(::typeof(*), args...) = mx.dot(args...)
|
||||||
node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||||
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||||
|
node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
|
||||||
|
node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
|
||||||
|
Loading…
Reference in New Issue
Block a user