2017-05-01 11:41:54 +00:00
|
|
|
export AArray, unsqueeze
|
2016-08-22 20:13:28 +00:00
|
|
|
|
2017-05-04 09:32:53 +00:00
|
|
|
"""
    call(f, xs...)

Apply `f` to the remaining arguments, i.e. `call(f, a, b) == f(a, b)`.
"""
function call(f, xs...)
  return f(xs...)
end
|
|
|
|
|
2017-05-01 15:57:51 +00:00
|
|
|
# Arrays
|
|
|
|
|
2016-08-22 20:13:28 +00:00
|
|
|
# Short alias for `AbstractArray` (exported at the top of this file).
const AArray = AbstractArray
|
2016-04-01 21:11:42 +00:00
|
|
|
|
2017-02-02 04:39:41 +00:00
|
|
|
"""
    initn(dims...)

Draw standard-normal samples with shape `dims` (a scalar when no dims are
given) and scale them down by a factor of 100. Presumably used for small
random weight initialisation — confirm against callers.
"""
function initn(dims...)
  samples = randn(dims...)
  return samples / 100
end
|
2016-08-25 16:25:33 +00:00
|
|
|
|
2017-05-01 11:41:54 +00:00
|
|
|
"""
    unsqueeze(xs, dim = 1)

Insert a singleton dimension into `xs` at position `dim`, so the result has one
more dimension than `xs`. Implemented with `reshape`, so no element data is
copied for ordinary arrays.
"""
function unsqueeze(xs, dim = 1)
  oldsize = size(xs)
  before = oldsize[1:dim-1]
  after = oldsize[dim:end]
  return reshape(xs, (before..., 1, after...))
end
|
2017-05-22 08:08:46 +00:00
|
|
|
"""
    squeeze(xs, dim = 1)

Drop dimension `dim` from `xs` by delegating to `Base.squeeze`; unlike the Base
function, the dimension argument here is optional and defaults to 1.
"""
function squeeze(xs, dim = 1)
  return Base.squeeze(xs, dim)
end
|
2017-04-19 13:23:48 +00:00
|
|
|
|
2017-05-01 11:41:54 +00:00
|
|
|
"""
    stack(xs, dim = 1)

Join the arrays in the collection `xs` along a new dimension at position `dim`:
each element gets a singleton dimension via `unsqueeze`, then all of them are
concatenated along `dim`. Counterpart of `unstack`.
"""
function stack(xs, dim = 1)
  expanded = unsqueeze.(xs, dim)
  return cat(dim, expanded...)
end
|
|
|
|
"""
    unstack(xs, dim = 1)

Split `xs` into a vector of its slices along dimension `dim`, in index order.
Each slice comes from `slicedim`, so dimension `dim` is dropped from it.
"""
function unstack(xs, dim = 1)
  nslices = size(xs, dim)
  return [slicedim(xs, dim, k) for k in 1:nslices]
end
|
2016-08-23 22:58:39 +00:00
|
|
|
|
2017-05-01 15:57:51 +00:00
|
|
|
"""
    convertel(T, xs::AbstractArray)

Return `xs` with every element converted to type `T` (by broadcasting
`convert`). When the element type of `xs` is already exactly `T`, `xs` itself is
returned without copying.
"""
function convertel(T::Type, xs::AbstractArray)
  return convert.(T, xs)
end

# More specific method: because array types are invariant, this only matches
# when eltype(xs) is exactly T, so the same array object is handed back.
function convertel{T}(::Type{T}, xs::AbstractArray{T})
  return xs
end
|
|
|
|
|
|
|
|
# Tuples
|
|
|
|
|
2017-05-01 11:41:54 +00:00
|
|
|
"""
    mapt(f, x)

Apply `f` to every non-tuple leaf of a (possibly nested) tuple, rebuilding the
same tuple structure around the results. A non-tuple `x` is treated as a single
leaf, so `mapt(f, x) == f(x)`.
"""
function mapt(f, x)
  return f(x)
end

# Tuple case: recurse into each element, preserving the nesting.
function mapt(f, xs::Tuple)
  return map(xs) do elem
    mapt(f, elem)
  end
end
|
|
|
|
|
2017-05-01 15:57:51 +00:00
|
|
|
"""
    collectt(xs)

Flatten a (possibly nested) tuple into a `Vector{Any}` of its non-tuple leaves,
in the order `mapt` visits them. A non-tuple `xs` yields a one-element vector.
"""
function collectt(xs)
  leaves = Any[]
  # `mapt` is used only for its traversal; its rebuilt tuple is discarded.
  mapt(xs) do leaf
    push!(leaves, leaf)
  end
  return leaves
end
|
|
|
|
|
|
|
|
"""
    shapecheckt(xs, ys)

Check that `ys` has the same nested-tuple structure as `xs`, calling `error` at
the first mismatch (a tuple length difference, or a non-tuple where `xs` has a
tuple). On success it returns `nothing` for a leaf, or the matching nested tuple
of `nothing`s.

Only the structure of `xs` is enforced: wherever `xs` is not a tuple, any `ys`
(including a tuple) is accepted.
"""
function shapecheckt(xs::Tuple, ys::Tuple)
  if length(xs) != length(ys)
    error("Expected tuple length $(length(xs)), got $ys")
  end
  # Lengths are equal here, so an element-wise map matches tuple broadcasting.
  return map(shapecheckt, xs, ys)
end

# Leaf of `xs`: nothing further to check.
shapecheckt(xs, ys) = nothing

# `xs` expects a tuple at this position but `ys` is not one.
shapecheckt(xs::Tuple, ys) = error("Expected tuple, got $ys")
|
|
|
|
|
|
|
|
# Other
|
2017-05-01 12:46:23 +00:00
|
|
|
|
|
|
|
"""
    accuracy(m, data)

Run model `m` over `data`, an iterable of `(x, y)` pairs, and return the ratio
of label matches to the total sample count. Each pair goes through `tobatch`
first. Matches are counted element-wise between `onecold(m(x))` and
`onecold(y)`. The sample count is the sum of `size(x, 1)` over all batched
inputs.

NOTE(review): this presumes the sample dimension is first after `tobatch` and
that `onecold` yields one label per sample; confirm against those helpers.
If `data` is empty, the result is `0/0`, i.e. `NaN`.
"""
function accuracy(m, data)
  seen = 0
  hits = 0
  for (x, y) in data
    xb, yb = tobatch.((x, y))
    seen += size(xb, 1)
    predicted = onecold(m(xb))
    hits += sum(predicted .== onecold(yb))
  end
  return hits / seen
end
|