diff --git a/src/Flux.jl b/src/Flux.jl index 75d2b2b3..4aa4d8e9 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D, param, params, mapleaves using NNlib -export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax conv2d, maxpool2d, avgpool2d include("tracker/Tracker.jl") diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index c8bb03b1..b8de5be1 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -130,12 +130,16 @@ end # NNlib using NNlib -import NNlib: softmax, ∇softmax, conv2d, pool +import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) +logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs)) + +back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) + # TODO: can store kwargs efficiently in namedtuples _conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad) diff --git a/test/tracker.jl b/test/tracker.jl index 90eb0af1..ca6f2cc7 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -18,6 +18,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> softmax(x).*(1:3), 3) @test gradtest(x -> softmax(x).*(1:3), (3,5)) +## uncomment the following test when logsoftmax has been added into NNlib.jl +#@test gradtest(x -> logsoftmax(x).*(1:3), 3) +#@test gradtest(x -> logsoftmax(x).*(1:3), (3,5)) + @test gradtest(Flux.mse, rand(5,5), rand(5, 5)) @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))