Don't extend base functions on base types
better broadcast syntax
This commit is contained in:
parent
cb4d8cf9a6
commit
5cbb47a13d
@ -9,7 +9,6 @@ const AArray = AbstractArray
|
|||||||
initn(dims...) = randn(dims...)/100
|
initn(dims...) = randn(dims...)/100
|
||||||
|
|
||||||
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
Base.squeeze(xs) = squeeze(xs, 1)
|
|
||||||
|
|
||||||
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
||||||
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||||
|
@ -13,5 +13,5 @@ end
|
|||||||
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
|
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
|
||||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
||||||
ru = unroll(r, 3)
|
ru = unroll(r, 3)
|
||||||
ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
|
ru(batchone(Seq(squeeze.(xs, 1))))[1] == squeeze.(ys, 1)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user