pull out tuple utils
This commit is contained in:
parent
2934607115
commit
357f989de5
@ -1,3 +1,5 @@
|
||||
using Flux: collectt, shapecheckt
|
||||
|
||||
struct AlterParam
|
||||
param
|
||||
load
|
||||
@ -46,22 +48,10 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
|
||||
mxungroup(x, outs) = copy(shift!(outs))
|
||||
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||
|
||||
function collectt(xs)
|
||||
ys = []
|
||||
mapt(x -> push!(ys, x), xs)
|
||||
return ys
|
||||
end
|
||||
|
||||
function dictt(ks::Tuple, vs, d = Dict())
|
||||
for i = 1:length(ks)
|
||||
dictt(ks[i], vs[i], d)
|
||||
end
|
||||
return d
|
||||
end
|
||||
|
||||
dictt(k, v, d = Dict()) = (d[k] = v; d)
|
||||
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
||||
|
||||
function executor(graph::Graph, input...)
|
||||
shapecheckt(graph.input, input)
|
||||
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
exec = mx.bind(mxgroup(graph.output),
|
||||
|
24
src/utils.jl
24
src/utils.jl
@ -1,5 +1,7 @@
|
||||
export AArray, unsqueeze
|
||||
|
||||
# Arrays
|
||||
|
||||
const AArray = AbstractArray
|
||||
|
||||
initn(dims...) = randn(dims...)/100
|
||||
@ -10,11 +12,29 @@ Base.squeeze(xs) = squeeze(xs, 1)
|
||||
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
||||
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||
|
||||
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
|
||||
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
||||
|
||||
# Tuples
|
||||
|
||||
mapt(f, x) = f(x)
|
||||
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
|
||||
|
||||
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
|
||||
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
||||
function collectt(xs)
|
||||
ys = []
|
||||
mapt(x -> push!(ys, x), xs)
|
||||
return ys
|
||||
end
|
||||
|
||||
function shapecheckt(xs::Tuple, ys::Tuple)
|
||||
length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys")
|
||||
shapecheckt.(xs, ys)
|
||||
end
|
||||
|
||||
shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys")
|
||||
shapecheckt(xs, ys) = nothing
|
||||
|
||||
# Other
|
||||
|
||||
function accuracy(m, data)
|
||||
n = 0
|
||||
|
Loading…
Reference in New Issue
Block a user