simpler types
This commit is contained in:
parent
0d568f6faf
commit
5de3e9f2b2
@ -1,19 +1,19 @@
|
|||||||
mutable struct Softmax{T,N,A,B} <: AbstractArray{T,N}
|
mutable struct Softmax{T,N,A} <: AbstractArray{T,N}
|
||||||
logits::A
|
logits::A
|
||||||
probs::B
|
probs::A
|
||||||
Softmax{T,N,A,B}(logits::A) where {T,N,A,B} = new(logits)
|
Softmax{T,N,A}(logits::A) where {T,N,A} = new(logits)
|
||||||
end
|
end
|
||||||
|
|
||||||
Softmax(logits::AbstractVecOrMat{<:AbstractFloat}) =
|
Softmax(logits::AbstractVecOrMat{<:AbstractFloat}) =
|
||||||
Softmax{eltype(logits),ndims(logits),typeof(logits),typeof(Tracker.data(logits))}(logits)
|
Softmax{eltype(logits),ndims(logits),typeof(logits)}(logits)
|
||||||
|
|
||||||
@forward Softmax.logits Base.size
|
@forward Softmax.logits Base.size
|
||||||
|
|
||||||
Base.IndexStyle(::Type{Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
|
Base.IndexStyle(::Type{Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
|
||||||
|
|
||||||
function Base.getindex(s::Softmax, i)
|
function Base.getindex(s::Softmax, i)
|
||||||
isdefined(s, :probs) || (s.probs = NNlib.softmax(Tracker.data(s.logits)))
|
isdefined(s, :probs) || (s.probs = NNlib.softmax(s.logits))
|
||||||
s.probs[i]
|
Tracker.data(s.probs)[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
softmax(xs::AbstractVecOrMat{<:AbstractFloat}) = Softmax(xs)
|
softmax(xs::AbstractVecOrMat{<:AbstractFloat}) = Softmax(xs)
|
||||||
|
Loading…
Reference in New Issue
Block a user