simpler types

This commit is contained in:
Mike J Innes 2017-10-18 10:10:24 +01:00
parent 0d568f6faf
commit 5de3e9f2b2
1 changed files with 6 additions and 6 deletions

View File

@ -1,19 +1,19 @@
mutable struct Softmax{T,N,A,B} <: AbstractArray{T,N}
mutable struct Softmax{T,N,A} <: AbstractArray{T,N}
logits::A
probs::B
Softmax{T,N,A,B}(logits::A) where {T,N,A,B} = new(logits)
probs::A
Softmax{T,N,A}(logits::A) where {T,N,A} = new(logits)
end
Softmax(logits::AbstractVecOrMat{<:AbstractFloat}) =
Softmax{eltype(logits),ndims(logits),typeof(logits),typeof(Tracker.data(logits))}(logits)
Softmax{eltype(logits),ndims(logits),typeof(logits)}(logits)
@forward Softmax.logits Base.size
Base.IndexStyle(::Type{Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
function Base.getindex(s::Softmax, i)
isdefined(s, :probs) || (s.probs = NNlib.softmax(Tracker.data(s.logits)))
s.probs[i]
isdefined(s, :probs) || (s.probs = NNlib.softmax(s.logits))
Tracker.data(s.probs)[i]
end
softmax(xs::AbstractVecOrMat{<:AbstractFloat}) = Softmax(xs)