diff --git a/src/Flux.jl b/src/Flux.jl
index d32c8194..a4d0f174 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -23,6 +23,7 @@ include("utils.jl")
 include("onehot.jl")
 include("tree.jl")
 
+include("layers/softmax.jl")
 include("layers/stateless.jl")
 include("layers/basic.jl")
 include("layers/recurrent.jl")
diff --git a/src/layers/softmax.jl b/src/layers/softmax.jl
new file mode 100644
index 00000000..d736597a
--- /dev/null
+++ b/src/layers/softmax.jl
@@ -0,0 +1,39 @@
+"""
+    Softmax(logits)
+
+Array wrapper around `logits` whose elements are the softmax probabilities.
+
+The probabilities are computed with `NNlib.softmax` on `Tracker.data(logits)` the
+first time an element is indexed and are then cached in `probs`. The logits stay
+available so that `crossentropy` can pass them to `logitcrossentropy`.
+"""
+mutable struct Softmax{T,N,A,B} <: AbstractArray{T,N}
+  logits::A
+  probs::B
+  # `probs` is left undefined here; `getindex` fills it in on first access.
+  Softmax{T,N,A,B}(logits::A) where {T,N,A,B} = new(logits)
+end
+
+Softmax(logits::AbstractVecOrMat{<:AbstractFloat}) =
+  Softmax{eltype(logits),ndims(logits),typeof(logits),typeof(Tracker.data(logits))}(logits)
+
+@forward Softmax.logits Base.size
+
+# `Softmax{T,N,A}` leaves the fourth parameter `B` free, so the signature must use
+# `Type{<:...}`: a plain `Type{Softmax{T,N,A}}` never matches a concrete
+# `Softmax{T,N,A,B}` and the index style of the logits would be ignored.
+Base.IndexStyle(::Type{<:Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
+
+# Compute and cache the probabilities on first access, then index into the cache.
+function Base.getindex(s::Softmax, i)
+  isdefined(s, :probs) || (s.probs = NNlib.softmax(Tracker.data(s.logits)))
+  s.probs[i]
+end
+
+softmax(xs::AbstractVecOrMat{<:AbstractFloat}) = Softmax(xs)
+
+# Non-float reals are converted to floating point before wrapping.
+softmax(xs::AbstractVecOrMat{<:Real}) = softmax(convert.(AbstractFloat, xs))
+
+# Tracked input: record `NNlib.softmax` as the call and wrap the tracked logits.
+softmax(xs::TrackedArray) = TrackedArray(Tracker.Call(NNlib.softmax, xs), Softmax(xs))
diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl
index 3931c216..655359df 100644
--- a/src/layers/stateless.jl
+++ b/src/layers/stateless.jl
@@ -12,3 +12,7 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
   ypred = logŷ .- log.(sum(exp.(logŷ), 1))
   -sum(y .* ypred) / size(y, 2)
 end
+
+# A `Softmax` still holds its logits, so use the logit-based form on them directly.
+crossentropy(ŷ::Union{Softmax,TrackedArray{<:Softmax}}, y::AbstractVecOrMat) =
+  logitcrossentropy(Tracker.data(ŷ).logits, y)