onehot
This commit is contained in:
parent
d7e3f7d6e1
commit
1855a37319
@ -20,9 +20,9 @@ include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
|
||||
include("compiler/Compiler.jl")
|
||||
using .Compiler: @net
|
||||
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
|
31
src/onehot.jl
Normal file
31
src/onehot.jl
Normal file
@ -0,0 +1,31 @@
|
||||
struct OneHotVector <: AbstractVector{Bool}
|
||||
ix::UInt32
|
||||
of::UInt32
|
||||
end
|
||||
|
||||
Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
||||
|
||||
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||
|
||||
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
|
||||
|
||||
struct OneHotMatrix <: AbstractMatrix{Bool}
|
||||
data::Vector{OneHotVector}
|
||||
end
|
||||
|
||||
Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
||||
|
||||
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||
|
||||
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
|
||||
onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
|
||||
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) =
|
||||
labels[findfirst(y, maximum(y))]
|
||||
|
||||
onecold(y::AbstractMatrix, l...) =
|
||||
squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
|
29
src/utils.jl
29
src/utils.jl
@ -2,32 +2,13 @@
|
||||
|
||||
initn(dims...) = randn(dims...)/100
|
||||
|
||||
"""
|
||||
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
|
||||
|
||||
onehot(Float32, 'c', ['a', 'b', 'c', 'd']) => [0., 0., 1., 0.]
|
||||
|
||||
Produce a one-hot-encoded version of an item, given a list of possible values
|
||||
for the item.
|
||||
"""
|
||||
onehot(T::Type, label, labels) = T[i == label for i in labels]
|
||||
onehot(label, labels) = onehot(Int, label, labels)
|
||||
|
||||
"""
|
||||
onecold([0.0, 1.0, 0.0, ...],
|
||||
['a', 'b', 'c', ...]) => 'b'
|
||||
|
||||
The inverse of `onehot`; takes an output prediction vector and a list of
|
||||
possible values, and produces the appropriate value.
|
||||
"""
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) =
|
||||
labels[findfirst(y, maximum(y))]
|
||||
|
||||
onecold(y::AbstractMatrix, l...) =
|
||||
squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
|
||||
|
||||
flatten(xs) = reshape(xs, size(xs, 1), :)
|
||||
|
||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||
|
||||
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
||||
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||
|
||||
# Other
|
||||
|
||||
function accuracy(m, data)
|
||||
|
@ -1,5 +1,6 @@
|
||||
using DataFlow, MacroTools
|
||||
using Flux.Compiler: @net, graph, stack, squeeze, unsqueeze
|
||||
using Flux: stack, unsqueeze
|
||||
using Flux.Compiler: @net, graph
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
@net type Affine
|
||||
@ -79,7 +80,7 @@ end
|
||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
|
||||
@test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
|
||||
ru = Flux.Compiler.unroll(r, 3)
|
||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
||||
ru(unsqueeze(stack(squeeze.(xs, 1), 1), 1))[1] == squeeze.(ys, 1)
|
||||
end
|
||||
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user