pull out tuple utils
This commit is contained in:
parent
2934607115
commit
357f989de5
@ -1,3 +1,5 @@
|
|||||||
|
using Flux: collectt, shapecheckt
|
||||||
|
|
||||||
struct AlterParam
|
struct AlterParam
|
||||||
param
|
param
|
||||||
load
|
load
|
||||||
@ -46,22 +48,10 @@ mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
|
|||||||
mxungroup(x, outs) = copy(shift!(outs))
|
mxungroup(x, outs) = copy(shift!(outs))
|
||||||
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||||
|
|
||||||
function collectt(xs)
|
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
||||||
ys = []
|
|
||||||
mapt(x -> push!(ys, x), xs)
|
|
||||||
return ys
|
|
||||||
end
|
|
||||||
|
|
||||||
function dictt(ks::Tuple, vs, d = Dict())
|
|
||||||
for i = 1:length(ks)
|
|
||||||
dictt(ks[i], vs[i], d)
|
|
||||||
end
|
|
||||||
return d
|
|
||||||
end
|
|
||||||
|
|
||||||
dictt(k, v, d = Dict()) = (d[k] = v; d)
|
|
||||||
|
|
||||||
function executor(graph::Graph, input...)
|
function executor(graph::Graph, input...)
|
||||||
|
shapecheckt(graph.input, input)
|
||||||
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||||
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||||
exec = mx.bind(mxgroup(graph.output),
|
exec = mx.bind(mxgroup(graph.output),
|
||||||
|
24
src/utils.jl
24
src/utils.jl
@ -1,5 +1,7 @@
|
|||||||
export AArray, unsqueeze
|
export AArray, unsqueeze
|
||||||
|
|
||||||
|
# Arrays
|
||||||
|
|
||||||
const AArray = AbstractArray
|
const AArray = AbstractArray
|
||||||
|
|
||||||
initn(dims...) = randn(dims...)/100
|
initn(dims...) = randn(dims...)/100
|
||||||
@ -10,11 +12,29 @@ Base.squeeze(xs) = squeeze(xs, 1)
|
|||||||
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
||||||
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||||
|
|
||||||
|
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
|
||||||
|
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
||||||
|
|
||||||
|
# Tuples
|
||||||
|
|
||||||
mapt(f, x) = f(x)
|
mapt(f, x) = f(x)
|
||||||
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
|
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
|
||||||
|
|
||||||
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
|
function collectt(xs)
|
||||||
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
ys = []
|
||||||
|
mapt(x -> push!(ys, x), xs)
|
||||||
|
return ys
|
||||||
|
end
|
||||||
|
|
||||||
|
function shapecheckt(xs::Tuple, ys::Tuple)
|
||||||
|
length(xs) == length(ys) || error("Expected tuple length $(length(xs)), got $ys")
|
||||||
|
shapecheckt.(xs, ys)
|
||||||
|
end
|
||||||
|
|
||||||
|
shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys")
|
||||||
|
shapecheckt(xs, ys) = nothing
|
||||||
|
|
||||||
|
# Other
|
||||||
|
|
||||||
function accuracy(m, data)
|
function accuracy(m, data)
|
||||||
n = 0
|
n = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user