simpler types
This commit is contained in:
parent
0d568f6faf
commit
5de3e9f2b2
|
@ -1,19 +1,19 @@
|
|||
mutable struct Softmax{T,N,A,B} <: AbstractArray{T,N}
|
||||
mutable struct Softmax{T,N,A} <: AbstractArray{T,N}
|
||||
logits::A
|
||||
probs::B
|
||||
Softmax{T,N,A,B}(logits::A) where {T,N,A,B} = new(logits)
|
||||
probs::A
|
||||
Softmax{T,N,A}(logits::A) where {T,N,A} = new(logits)
|
||||
end
|
||||
|
||||
Softmax(logits::AbstractVecOrMat{<:AbstractFloat}) =
|
||||
Softmax{eltype(logits),ndims(logits),typeof(logits),typeof(Tracker.data(logits))}(logits)
|
||||
Softmax{eltype(logits),ndims(logits),typeof(logits)}(logits)
|
||||
|
||||
@forward Softmax.logits Base.size
|
||||
|
||||
Base.IndexStyle(::Type{Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
|
||||
|
||||
function Base.getindex(s::Softmax, i)
|
||||
isdefined(s, :probs) || (s.probs = NNlib.softmax(Tracker.data(s.logits)))
|
||||
s.probs[i]
|
||||
isdefined(s, :probs) || (s.probs = NNlib.softmax(s.logits))
|
||||
Tracker.data(s.probs)[i]
|
||||
end
|
||||
|
||||
softmax(xs::AbstractVecOrMat{<:AbstractFloat}) = Softmax(xs)
|
||||
|
|
Loading…
Reference in New Issue