Compare commits
3 Commits
Author | SHA1 | Date |
---|---|---|
![]() |
5de3e9f2b2 | |
![]() |
0d568f6faf | |
![]() |
2854ae101b |
|
@@ -23,6 +23,7 @@ include("utils.jl")
|
|||
include("onehot.jl")
|
||||
include("tree.jl")
|
||||
|
||||
include("layers/softmax.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
include("layers/recurrent.jl")
|
||||
|
|
|
@@ -0,0 +1,23 @@
|
|||
"""
    Softmax{T,N,A} <: AbstractArray{T,N}

Array-typed wrapper around `logits::A`, with a second field `probs::A` of the
same type.

The inner constructor assigns only `logits`; `probs` is left unassigned on a
freshly built instance (the struct is mutable so it can be set afterwards).
"""
mutable struct Softmax{T,N,A} <: AbstractArray{T,N}
    logits::A
    probs::A  # deliberately not initialised by the constructor below

    function Softmax{T,N,A}(logits::A) where {T,N,A}
        # Incomplete initialisation: only the first field is supplied.
        return new(logits)
    end
end
|
||||
|
||||
"""
    Softmax(logits::AbstractVecOrMat{<:AbstractFloat})

Wrap `logits` in a `Softmax`, taking the three type parameters from the
argument's element type, dimensionality and concrete array type.
"""
function Softmax(logits::AbstractVecOrMat{<:AbstractFloat})
    T = eltype(logits)
    N = ndims(logits)
    return Softmax{T,N,typeof(logits)}(logits)
end
|
||||
|
||||
# NOTE(review): `@forward` is defined outside this view; presumably this makes
# `Base.size` on a `Softmax` delegate to its `logits` field — confirm.
@forward Softmax.logits Base.size
|
||||
|
||||
# A `Softmax` reports the same index style as the array type `A` it wraps.
function Base.IndexStyle(::Type{Softmax{T,N,A}}) where {T,N,A}
    return IndexStyle(A)
end
|
||||
|
||||
"""
    getindex(s::Softmax, i)

Read index `i` of the probabilities belonging to `s`.

On the first read, `probs` is still unassigned, so it is filled with
`NNlib.softmax(s.logits)`; later reads reuse the stored value. The lookup is
performed on `Tracker.data(s.probs)`.
"""
function Base.getindex(s::Softmax, i)
    if !isdefined(s, :probs)
        # Lazily compute and cache the probabilities.
        s.probs = NNlib.softmax(s.logits)
    end
    probs = Tracker.data(s.probs)
    return probs[i]
end
|
||||
|
||||
# Floating-point vectors/matrices are wrapped in a `Softmax` directly.
function softmax(xs::AbstractVecOrMat{<:AbstractFloat})
    return Softmax(xs)
end
|
||||
|
||||
# Other real-valued input is converted element-wise to `AbstractFloat` first,
# then passed back through `softmax`.
function softmax(xs::AbstractVecOrMat{<:Real})
    floats = convert.(AbstractFloat, xs)
    return softmax(floats)
end
|
||||
|
||||
# Tracked input: build a `TrackedArray` from a `Tracker.Call` of `NNlib.softmax`
# on `xs`, paired with a `Softmax` wrapper around the same `xs`.
function softmax(xs::TrackedArray)
    call = Tracker.Call(NNlib.softmax, xs)
    wrapped = Softmax(xs)
    return TrackedArray(call, wrapped)
end
|
|
@@ -12,3 +12,6 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
|
|||
ypred = logŷ .- log.(sum(exp.(logŷ), 1))
|
||||
-sum(y .* ypred) / size(y, 2)
|
||||
end
|
||||
|
||||
"""
    crossentropy(ŷ::Union{Softmax,TrackedArray{<:Softmax}}, y::AbstractVecOrMat)

Cross-entropy for predictions that are a `Softmax` (or a `TrackedArray` of
one): the `logits` stored on `Tracker.data(ŷ)` are handed to
`logitcrossentropy` together with `y`.
"""
function crossentropy(ŷ::Union{Softmax,TrackedArray{<:Softmax}}, y::AbstractVecOrMat)
    logits = Tracker.data(ŷ).logits
    return logitcrossentropy(logits, y)
end
|
||||
|
|
|
@@ -11,8 +11,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
|
||||
# Broadcast `sin` over a sum across dimensions 2 and 3; the dims tuple (3,4,5)
# is turned into a `rand` array by the `gradtest(f, dims...)` method.
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
||||
|
||||
# `softmax` and `NNlib.softmax`, each scaled element-wise by 1:3, are run
# through `gradtest` on a length-3 vector and on a 3×5 matrix of random values.
# NOTE(review): this is a diff view, so the unqualified and `NNlib.`-qualified
# pairs may be the before/after sides of the same two lines — confirm.
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||
@test gradtest(x -> NNlib.softmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> NNlib.softmax(x).*(1:3), (3,5))
|
||||
|
||||
# Loss functions are called with explicit 5×5 random arrays instead of dims.
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
|
Loading…
Reference in New Issue