Merge pull request #153 from boathit/master

Registering backward function for logsoftmax
This commit is contained in:
Mike J Innes 2018-01-23 13:54:55 +00:00 committed by GitHub
commit ed8d026723
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 2 deletions

View File

@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
param, params, mapleaves
using NNlib
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax
conv2d, maxpool2d, avgpool2d
include("tracker/Tracker.jl")

View File

@ -130,12 +130,16 @@ end
# NNlib
using NNlib
import NNlib: softmax, ∇softmax, conv2d, pool
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)

View File

@ -18,6 +18,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
## uncomment the following test when logsoftmax has been added into NNlib.jl
#@test gradtest(x -> logsoftmax(x).*(1:3), 3)
#@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))