lazy softmax

This commit is contained in:
Mike J Innes 2017-10-17 19:40:58 +01:00
parent b26f77489e
commit 2854ae101b
3 changed files with 27 additions and 0 deletions

View File

@ -23,6 +23,7 @@ include("utils.jl")
include("onehot.jl")
include("tree.jl")
include("layers/softmax.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/recurrent.jl")

23
src/layers/softmax.jl Normal file
View File

@ -0,0 +1,23 @@
"""
    Softmax{T,N,A,B} <: AbstractArray{T,N}

Array wrapper that holds `logits` (of type `A`) and a `probs` field (of type `B`).

The only inner constructor takes the logits and leaves `probs` uninitialised, so
`probs` must be assigned before it is read. The type is mutable so that `probs` can
be filled in after construction.
"""
mutable struct Softmax{T,N,A,B} <: AbstractArray{T,N}
  logits::A
  probs::B  # not set by the constructor; check with `isdefined` before reading
  function Softmax{T,N,A,B}(logits::A) where {T,N,A,B}
    # Incomplete initialisation: only `logits` is passed to `new`.
    return new(logits)
  end
end
# Convenience constructor: derive all four type parameters from `logits`.
# The `probs` field type is taken from `Tracker.data(logits)`.
# NOTE(review): presumably `Tracker.data` returns the plain array underlying a
# tracked array (and the input itself otherwise) — confirm against Tracker.
function Softmax(logits::AbstractVecOrMat{<:AbstractFloat})
  T = eltype(logits)
  N = ndims(logits)
  A = typeof(logits)
  B = typeof(Tracker.data(logits))
  return Softmax{T,N,A,B}(logits)
end
# Delegate `Base.size` on a `Softmax` to its `logits` field.
# NOTE(review): `@forward` is defined outside this file section; presumably it
# generates `Base.size(s::Softmax, args...)` calling `Base.size(s.logits, args...)`
# — confirm against the macro's definition.
@forward Softmax.logits Base.size
# Report the same index style as the wrapped logits array type `A`.
#
# The signature must use `Type{<:Softmax{T,N,A}}`: `Softmax` has four type
# parameters, so `Softmax{T,N,A}` is the UnionAll `Softmax{T,N,A,B} where B`.
# `Type{...}` is invariant, so without `<:` the method would only match that
# UnionAll object itself and never a concrete `Softmax{T,N,A,B}`.
Base.IndexStyle(::Type{<:Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
# Index into the probabilities of a `Softmax`, computing them on first access.
# `probs` is left undefined by the constructor; the first lookup fills it with
# `NNlib.softmax` applied to `Tracker.data(s.logits)`, and later lookups reuse it.
function Base.getindex(s::Softmax, i)
  if !isdefined(s, :probs)
    s.probs = NNlib.softmax(Tracker.data(s.logits))
  end
  return s.probs[i]
end
# Floating-point input: wrap the logits in a `Softmax` without computing anything here.
function softmax(xs::AbstractVecOrMat{<:AbstractFloat})
  return Softmax(xs)
end
# Non-float real input: convert every element to a floating-point value, then
# re-dispatch `softmax` on the converted array.
function softmax(xs::AbstractVecOrMat{<:Real})
  floats = convert.(AbstractFloat, xs)
  return softmax(floats)
end
# Tracked input: build a `TrackedArray` from a `Tracker.Call` on `NNlib.softmax`
# with `xs`, together with a `Softmax` wrapping `xs`.
# NOTE(review): presumably the first argument records the operation for
# differentiation and the second is the array's value — confirm against Tracker.
function softmax(xs::TrackedArray)
  call = Tracker.Call(NNlib.softmax, xs)
  value = Softmax(xs)
  return TrackedArray(call, value)
end

View File

@ -12,3 +12,6 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
ypred = logŷ .- log.(sum(exp.(logŷ), 1))
-sum(y .* ypred) / size(y, 2)
end
# Cross-entropy specialised for `Softmax` predictions (plain or wrapped in a
# `TrackedArray`): read the stored logits and hand them to `logitcrossentropy`
# instead of going through the probabilities.
#
# The prediction argument must be named and passed to `Tracker.data`; calling
# `Tracker.data()` with no argument cannot reach the logits.
# NOTE(review): presumably `Tracker.data(ŷ)` yields the underlying `Softmax` for a
# tracked input and `ŷ` itself otherwise — confirm against Tracker.
crossentropy(ŷ::Union{Softmax,TrackedArray{<:Softmax}}, y::AbstractVecOrMat) =
  logitcrossentropy(Tracker.data(ŷ).logits, y)