Compare commits

...

3 Commits

Author SHA1 Message Date
Mike J Innes 5de3e9f2b2 simpler types 2017-10-18 10:23:42 +01:00
Mike J Innes 0d568f6faf fix tests 2017-10-18 10:23:42 +01:00
Mike J Innes 2854ae101b lazy softmax 2017-10-18 10:23:42 +01:00
4 changed files with 29 additions and 2 deletions

View File

@ -23,6 +23,7 @@ include("utils.jl")
include("onehot.jl")
include("tree.jl")
include("layers/softmax.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/recurrent.jl")

23
src/layers/softmax.jl Normal file
View File

@ -0,0 +1,23 @@
mutable struct Softmax{T,N,A} <: AbstractArray{T,N}
logits::A
probs::A
Softmax{T,N,A}(logits::A) where {T,N,A} = new(logits)
end
Softmax(logits::AbstractVecOrMat{<:AbstractFloat}) =
Softmax{eltype(logits),ndims(logits),typeof(logits)}(logits)
@forward Softmax.logits Base.size
Base.IndexStyle(::Type{Softmax{T,N,A}}) where {T,N,A} = IndexStyle(A)
function Base.getindex(s::Softmax, i)
isdefined(s, :probs) || (s.probs = NNlib.softmax(s.logits))
Tracker.data(s.probs)[i]
end
softmax(xs::AbstractVecOrMat{<:AbstractFloat}) = Softmax(xs)
softmax(xs::AbstractVecOrMat{<:Real}) = softmax(convert.(AbstractFloat, xs))
softmax(xs::TrackedArray) = TrackedArray(Tracker.Call(NNlib.softmax, xs), Softmax(xs))

View File

@ -12,3 +12,6 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
ypred = logŷ .- log.(sum(exp.(logŷ), 1))
-sum(y .* ypred) / size(y, 2)
end
crossentropy(::Union{Softmax,TrackedArray{<:Softmax}}, y::AbstractVecOrMat) =
logitcrossentropy(Tracker.data().logits, y)

View File

@ -11,8 +11,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(x -> NNlib.softmax(x).*(1:3), 3)
@test gradtest(x -> NNlib.softmax(x).*(1:3), (3,5))
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))