diff --git a/src/Flux.jl b/src/Flux.jl
index 1b4cbbc7..a48b7b90 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -20,9 +20,9 @@ include("optimise/Optimise.jl")
 using .Optimise
 
 include("utils.jl")
+include("onehot.jl")
 
 include("compiler/Compiler.jl")
-using .Compiler: @net
 
 include("layers/stateless.jl")
 include("layers/basic.jl")
diff --git a/src/onehot.jl b/src/onehot.jl
new file mode 100644
index 00000000..392fc879
--- /dev/null
+++ b/src/onehot.jl
@@ -0,0 +1,72 @@
+"""
+    OneHotVector(ix, of)
+
+A length-`of` Boolean vector whose only `true` entry is at index `ix`.
+Only the two integers are stored; elements are computed on demand.
+"""
+struct OneHotVector <: AbstractVector{Bool}
+  ix::UInt32
+  of::UInt32
+end
+
+Base.size(xs::OneHotVector) = (Int64(xs.of),)
+
+Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
+
+# Multiplying by a one-hot vector just selects column `ix` of `A`.
+Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
+
+"""
+    OneHotMatrix(data)
+
+A Boolean matrix whose `j`th column is the one-hot vector `data[j]`.
+"""
+struct OneHotMatrix <: AbstractMatrix{Bool}
+  data::Vector{OneHotVector}
+end
+
+# With no columns there is no `data[1]` to take the height from, so report an
+# empty 0-by-0 matrix rather than throwing a `BoundsError`.
+Base.size(xs::OneHotMatrix) =
+  isempty(xs.data) ? (0, 0) : (Int64(length(xs.data[1])), length(xs.data))
+
+Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
+
+# Column `j` of the product is the column of `A` selected by `B.data[j]`.
+Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
+
+Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
+
+"""
+    onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
+
+Produce a one-hot-encoded version of an item, given a list of possible values
+for the item. Throws an error if the item is not one of the possible values.
+"""
+function onehot(l, labels)
+  i = findfirst(labels, l)
+  # `findfirst` returns 0 for a missing value, which would otherwise silently
+  # produce an all-`false` vector.
+  i > 0 || error("Value $l is not in labels")
+  return OneHotVector(i, length(labels))
+end
+
+"""
+    onehotbatch(['b', 'a'], ['a', 'b', 'c']) => [false true; true false; false false]
+
+Apply `onehot` to each item of `ls`, giving a matrix with one column per item.
+"""
+onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
+
+"""
+    onecold([0.0, 1.0, 0.0, ...],
+             ['a', 'b', 'c', ...]) => 'b'
+
+The inverse of `onehot`; takes an output prediction vector and a list of
+possible values, and produces the appropriate value.
+"""
+onecold(y::AbstractVector, labels = 1:length(y)) =
+  labels[findfirst(y, maximum(y))]
+
+onecold(y::AbstractMatrix, l...) =
+  squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
diff --git a/src/utils.jl b/src/utils.jl
index 648fcff9..61670223 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -2,32 +2,13 @@
 
 initn(dims...) = randn(dims...)/100
 
-"""
-    onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
-
-    onehot(Float32, 'c', ['a', 'b', 'c', 'd']) => [0., 0., 1., 0.]
-
-Produce a one-hot-encoded version of an item, given a list of possible values
-for the item.
-"""
-onehot(T::Type, label, labels) = T[i == label for i in labels]
-onehot(label, labels) = onehot(Int, label, labels)
-
-"""
-    onecold([0.0, 1.0, 0.0, ...],
-             ['a', 'b', 'c', ...]) => 'b'
-
-The inverse of `onehot`; takes an output prediction vector and a list of
-possible values, and produces the appropriate value.
-"""
-onecold(y::AbstractVector, labels = 1:length(y)) =
-  labels[findfirst(y, maximum(y))]
-
-onecold(y::AbstractMatrix, l...) =
-  squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
-
 flatten(xs) = reshape(xs, size(xs, 1), :)
 
+unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+
+stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
+unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
+
 # Other
 
 function accuracy(m, data)
diff --git a/test/compiler.jl b/test/compiler.jl
index e550b14f..a82550e8 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -1,5 +1,6 @@
 using DataFlow, MacroTools
-using Flux.Compiler: @net, graph, stack, squeeze, unsqueeze
+using Flux: stack, unsqueeze
+using Flux.Compiler: @net, graph
 using DataFlow: Line, Frame
 
 @net type Affine
@@ -79,7 +80,7 @@ end
   _, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
   @test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
   ru = Flux.Compiler.unroll(r, 3)
-  ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
+  ru(unsqueeze(stack(squeeze.(xs, 1), 1), 1))[1] == squeeze.(ys, 1)
 end
 
 end