pull out tuple utils

This commit is contained in:
Mike J Innes 2017-05-01 16:57:51 +01:00
parent 2934607115
commit 357f989de5
2 changed files with 26 additions and 16 deletions

View File

@ -1,3 +1,5 @@
using Flux: collectt, shapecheckt
struct AlterParam
param
load
@ -46,22 +48,10 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
function collectt(xs)
ys = []
mapt(x -> push!(ys, x), xs)
return ys
end
function dictt(ks::Tuple, vs, d = Dict())
for i = 1:length(ks)
dictt(ks[i], vs[i], d)
end
return d
end
dictt(k, v, d = Dict()) = (d[k] = v; d)
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
function executor(graph::Graph, input...)
shapecheckt(graph.input, input)
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
exec = mx.bind(mxgroup(graph.output),

View File

@ -1,5 +1,7 @@
export AArray, unsqueeze
# Arrays
const AArray = AbstractArray
initn(dims...) = randn(dims...)/100
@ -10,11 +12,29 @@ Base.squeeze(xs) = squeeze(xs, 1)
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
# Tuples
mapt(f, x) = f(x)
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
function collectt(xs)
ys = []
mapt(x -> push!(ys, x), xs)
return ys
end
function shapecheckt(xs::Tuple, ys::Tuple)
length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys")
shapecheckt.(xs, ys)
end
shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys")
shapecheckt(xs, ys) = nothing
# Other
function accuracy(m, data)
n = 0