diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 2bb33202..cdeb1b2d 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -1,3 +1,5 @@ +using Flux: collectt, shapecheckt + struct AlterParam param load @@ -46,22 +48,10 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...) mxungroup(x, outs) = copy(shift!(outs)) mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x) -function collectt(xs) - ys = [] - mapt(x -> push!(ys, x), xs) - return ys -end - -function dictt(ks::Tuple, vs, d = Dict()) - for i = 1:length(ks) - dictt(ks[i], vs[i], d) - end - return d -end - -dictt(k, v, d = Dict()) = (d[k] = v; d) +dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys))) function executor(graph::Graph, input...) + shapecheckt(graph.input, input) args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input))) grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input))) exec = mx.bind(mxgroup(graph.output), diff --git a/src/utils.jl b/src/utils.jl index daac39cf..874cb9b2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,7 @@ export AArray, unsqueeze +# Arrays + const AArray = AbstractArray initn(dims...) = randn(dims...)/100 @@ -10,11 +12,29 @@ Base.squeeze(xs) = squeeze(xs, 1) stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...) unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] +convertel(T::Type, xs::AbstractArray) = convert.(T, xs) +convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs + +# Tuples + mapt(f, x) = f(x) mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs) -convertel(T::Type, xs::AbstractArray) = convert.(T, xs) -convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs +function collectt(xs) + ys = [] + mapt(x -> push!(ys, x), xs) + return ys +end + +function shapecheckt(xs::Tuple, ys::Tuple) + length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys") + shapecheckt.(xs, ys) +end + +shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys") +shapecheckt(xs, ys) = nothing + +# Other function accuracy(m, data) n = 0