onehot
This commit is contained in:
parent
d7e3f7d6e1
commit
1855a37319
@ -20,9 +20,9 @@ include("optimise/Optimise.jl")
|
|||||||
using .Optimise
|
using .Optimise
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
include("onehot.jl")
|
||||||
|
|
||||||
include("compiler/Compiler.jl")
|
include("compiler/Compiler.jl")
|
||||||
using .Compiler: @net
|
|
||||||
|
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("layers/basic.jl")
|
include("layers/basic.jl")
|
||||||
|
31
src/onehot.jl
Normal file
31
src/onehot.jl
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
struct OneHotVector <: AbstractVector{Bool}
|
||||||
|
ix::UInt32
|
||||||
|
of::UInt32
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.size(xs::OneHotVector) = (Int64(xs.of),)
|
||||||
|
|
||||||
|
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||||
|
|
||||||
|
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
|
||||||
|
|
||||||
|
struct OneHotMatrix <: AbstractMatrix{Bool}
|
||||||
|
data::Vector{OneHotVector}
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
|
||||||
|
|
||||||
|
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
||||||
|
|
||||||
|
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
||||||
|
|
||||||
|
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||||
|
|
||||||
|
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
|
||||||
|
onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
|
||||||
|
|
||||||
|
onecold(y::AbstractVector, labels = 1:length(y)) =
|
||||||
|
labels[findfirst(y, maximum(y))]
|
||||||
|
|
||||||
|
onecold(y::AbstractMatrix, l...) =
|
||||||
|
squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
|
29
src/utils.jl
29
src/utils.jl
@ -2,32 +2,13 @@
|
|||||||
|
|
||||||
initn(dims...) = randn(dims...)/100
|
initn(dims...) = randn(dims...)/100
|
||||||
|
|
||||||
"""
|
|
||||||
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
|
|
||||||
|
|
||||||
onehot(Float32, 'c', ['a', 'b', 'c', 'd']) => [0., 0., 1., 0.]
|
|
||||||
|
|
||||||
Produce a one-hot-encoded version of an item, given a list of possible values
|
|
||||||
for the item.
|
|
||||||
"""
|
|
||||||
onehot(T::Type, label, labels) = T[i == label for i in labels]
|
|
||||||
onehot(label, labels) = onehot(Int, label, labels)
|
|
||||||
|
|
||||||
"""
|
|
||||||
onecold([0.0, 1.0, 0.0, ...],
|
|
||||||
['a', 'b', 'c', ...]) => 'b'
|
|
||||||
|
|
||||||
The inverse of `onehot`; takes an output prediction vector and a list of
|
|
||||||
possible values, and produces the appropriate value.
|
|
||||||
"""
|
|
||||||
onecold(y::AbstractVector, labels = 1:length(y)) =
|
|
||||||
labels[findfirst(y, maximum(y))]
|
|
||||||
|
|
||||||
onecold(y::AbstractMatrix, l...) =
|
|
||||||
squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
|
|
||||||
|
|
||||||
flatten(xs) = reshape(xs, size(xs, 1), :)
|
flatten(xs) = reshape(xs, size(xs, 1), :)
|
||||||
|
|
||||||
|
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
|
|
||||||
|
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
||||||
|
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||||
|
|
||||||
# Other
|
# Other
|
||||||
|
|
||||||
function accuracy(m, data)
|
function accuracy(m, data)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
using DataFlow, MacroTools
|
using DataFlow, MacroTools
|
||||||
using Flux.Compiler: @net, graph, stack, squeeze, unsqueeze
|
using Flux: stack, unsqueeze
|
||||||
|
using Flux.Compiler: @net, graph
|
||||||
using DataFlow: Line, Frame
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
@net type Affine
|
@net type Affine
|
||||||
@ -79,7 +80,7 @@ end
|
|||||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
|
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
|
||||||
@test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
|
@test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
|
||||||
ru = Flux.Compiler.unroll(r, 3)
|
ru = Flux.Compiler.unroll(r, 3)
|
||||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
ru(unsqueeze(stack(squeeze.(xs, 1), 1), 1))[1] == squeeze.(ys, 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user