Registering backward function for logsoftmax

boathit 2018-01-21 15:20:59 +08:00
parent 72eabde373
commit 374d7a5f1e
3 changed files with 10 additions and 2 deletions

@@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
   param, params, mapleaves
 using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
   conv2d, maxpool2d, avgpool2d
 include("tracker/Tracker.jl")
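
With logsoftmax added to the export list, the function is reachable directly from the Flux namespace rather than only through NNlib. A minimal usage sketch, assuming this patched checkout (the input values are arbitrary; param is already exported in the lines above):

    using Flux

    x = param(rand(3))   # param wraps the array as a TrackedArray
    y = logsoftmax(x)    # dispatches to the TrackedArray method registered below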

@@ -130,12 +130,16 @@ end
 # NNlib
 using NNlib
-import NNlib: softmax, ∇softmax, conv2d, pool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
 softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
 back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
+logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
+back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
 # TODO: can store kwargs efficiently in namedtuples
 _conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
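
For reference, the rule registered here relies on the logsoftmax pullback: with y = logsoftmax(x) = x .- log(sum(exp.(x))), the gradient contribution for an upstream Δ is Δ .- softmax(x) .* sum(Δ), which is what ∇logsoftmax is expected to compute in NNlib. A minimal sketch of that identity for a plain vector input; naive_logsoftmax and naive_logsoftmax_vjp are illustrative names, not part of NNlib:

    # Reference logsoftmax with the usual max subtraction for numerical stability.
    function naive_logsoftmax(x::AbstractVector)
        m = maximum(x)
        x .- m .- log(sum(exp.(x .- m)))
    end

    # Vector-Jacobian product: with y = logsoftmax(x), the pullback of Δ is
    # Δ .- softmax(x) .* sum(Δ), and softmax(x) == exp.(logsoftmax(x)).
    naive_logsoftmax_vjp(Δ, x) = Δ .- exp.(naive_logsoftmax(x)) .* sum(Δ)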

@@ -18,6 +18,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 @test gradtest(x -> softmax(x).*(1:3), 3)
 @test gradtest(x -> softmax(x).*(1:3), (3,5))
+## uncomment the following tests once logsoftmax has been added to NNlib.jl
+#@test gradtest(x -> logsoftmax(x).*(1:3), 3)
+#@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
 @test gradtest(Flux.mse, rand(5,5), rand(5, 5))
 @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
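
Once the commented tests are enabled, gradtest will exercise the new rule by comparing tracked gradients against numerical ones. A rough sketch of that kind of check, reusing the reference functions from the sketch above; ngrad is a hypothetical helper written here for illustration, not the test suite's own:

    # Central-difference gradient of a scalar-valued f at x.
    function ngrad(f, x)
        δ = sqrt(eps())
        g = similar(x)
        for i in eachindex(x)
            tmp = x[i]
            x[i] = tmp - δ/2; y1 = f(x)
            x[i] = tmp + δ/2; y2 = f(x)
            x[i] = tmp
            g[i] = (y2 - y1) / δ
        end
        g
    end

    x = rand(3)
    Δ = ones(3)   # gradient of sum(logsoftmax(x)) corresponds to Δ = ones
    @assert isapprox(naive_logsoftmax_vjp(Δ, x),
                     ngrad(z -> sum(naive_logsoftmax(z)), x), atol = 1e-5)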