Registering backward function for logsoftmax
This commit is contained in:
parent
72eabde373
commit
374d7a5f1e
|
@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
|
|||
param, params, mapleaves
|
||||
|
||||
using NNlib
|
||||
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
|
||||
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax
|
||||
conv2d, maxpool2d, avgpool2d
|
||||
|
||||
include("tracker/Tracker.jl")
|
||||
|
|
|
@ -130,12 +130,16 @@ end
|
|||
# NNlib
|
||||
|
||||
using NNlib
|
||||
import NNlib: softmax, ∇softmax, conv2d, pool
|
||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
|
||||
|
||||
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||
|
||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||
|
||||
logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
|
||||
|
||||
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
|
||||
|
||||
# TODO: can store kwargs efficiently in namedtuples
|
||||
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
|
||||
|
||||
|
|
|
@ -18,6 +18,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||
|
||||
## uncomment the following test when logsoftmax has been added into NNlib.jl
|
||||
#@test gradtest(x -> logsoftmax(x).*(1:3), 3)
|
||||
#@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
|
||||
|
||||
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
||||
|
|
Loading…
Reference in New Issue