lazy softmax
This commit is contained in:
parent
b26f77489e
commit
2854ae101b
@ -23,6 +23,7 @@ include("utils.jl")
|
|||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
include("tree.jl")
|
include("tree.jl")
|
||||||
|
|
||||||
|
include("layers/softmax.jl")
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("layers/basic.jl")
|
include("layers/basic.jl")
|
||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
|
23
src/layers/softmax.jl
Normal file
23
src/layers/softmax.jl
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
mutable struct Softmax{T,N,A,B} <: AbstractArray{T,N}
|
||||||
|
logits::A
|
||||||
|
probs::B
|
||||||
|
Softmax{T,N,A,B}(logits::A) where {T,N,A,B} = new(logits)
|
||||||
|
end
|
||||||
|
|
||||||
|
Softmax(logits::AbstractVecOrMat{<:AbstractFloat}) =
|
||||||
|
Softmax{eltype(logits),ndims(logits),typeof(logits),typeof(Tracker.data(logits))}(logits)
|
||||||
|
|
||||||
|
@forward Softmax.logits Base.size
|
||||||
|
|
||||||
|
Base.IndexStyle(::Type{Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
|
||||||
|
|
||||||
|
function Base.getindex(s::Softmax, i)
|
||||||
|
isdefined(s, :probs) || (s.probs = NNlib.softmax(Tracker.data(s.logits)))
|
||||||
|
s.probs[i]
|
||||||
|
end
|
||||||
|
|
||||||
|
softmax(xs::AbstractVecOrMat{<:AbstractFloat}) = Softmax(xs)
|
||||||
|
|
||||||
|
softmax(xs::AbstractVecOrMat{<:Real}) = softmax(convert.(AbstractFloat, xs))
|
||||||
|
|
||||||
|
softmax(xs::TrackedArray) = TrackedArray(Tracker.Call(NNlib.softmax, xs), Softmax(xs))
|
@ -12,3 +12,6 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
|
|||||||
ypred = logŷ .- log.(sum(exp.(logŷ), 1))
|
ypred = logŷ .- log.(sum(exp.(logŷ), 1))
|
||||||
-sum(y .* ypred) / size(y, 2)
|
-sum(y .* ypred) / size(y, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
crossentropy(ŷ::Union{Softmax,TrackedArray{<:Softmax}}, y::AbstractVecOrMat) =
|
||||||
|
logitcrossentropy(Tracker.data(ŷ).logits, y)
|
||||||
|
Loading…
Reference in New Issue
Block a user