This commit is contained in:
Mike J Innes 2017-09-06 18:58:55 -04:00
parent d7e3f7d6e1
commit 1855a37319
4 changed files with 40 additions and 27 deletions

View File

@ -20,9 +20,9 @@ include("optimise/Optimise.jl")
using .Optimise
include("utils.jl")
include("onehot.jl")
include("compiler/Compiler.jl")
using .Compiler: @net
include("layers/stateless.jl")
include("layers/basic.jl")

31
src/onehot.jl Normal file
View File

@ -0,0 +1,31 @@
struct OneHotVector <: AbstractVector{Bool}
ix::UInt32
of::UInt32
end
Base.size(xs::OneHotVector) = (Int64(xs.of),)
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
struct OneHotMatrix <: AbstractMatrix{Bool}
data::Vector{OneHotVector}
end
Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
onecold(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))]
onecold(y::AbstractMatrix, l...) =
squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)

View File

@ -2,32 +2,13 @@
initn(dims...) = randn(dims...)/100
"""
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
onehot(Float32, 'c', ['a', 'b', 'c', 'd']) => [0., 0., 1., 0.]
Produce a one-hot-encoded version of an item, given a list of possible values
for the item.
"""
onehot(T::Type, label, labels) = T[i == label for i in labels]
onehot(label, labels) = onehot(Int, label, labels)
"""
onecold([0.0, 1.0, 0.0, ...],
['a', 'b', 'c', ...]) => 'b'
The inverse of `onehot`; takes an output prediction vector and a list of
possible values, and produces the appropriate value.
"""
onecold(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))]
onecold(y::AbstractMatrix, l...) =
squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
flatten(xs) = reshape(xs, size(xs, 1), :)
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
# Other
function accuracy(m, data)

View File

@ -1,5 +1,6 @@
using DataFlow, MacroTools
using Flux.Compiler: @net, graph, stack, squeeze, unsqueeze
using Flux: stack, unsqueeze
using Flux.Compiler: @net, graph
using DataFlow: Line, Frame
@net type Affine
@ -79,7 +80,7 @@ end
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
@test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
ru = Flux.Compiler.unroll(r, 3)
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
ru(unsqueeze(stack(squeeze.(xs, 1), 1), 1))[1] == squeeze.(ys, 1)
end
end