Merge pull request #226 from CarloLucibello/reshape

fix reshape
This commit is contained in:
Mike J Innes 2018-04-15 16:53:21 +01:00 committed by GitHub
commit 642543808e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 5 deletions

View File

@ -117,11 +117,9 @@ function back(::typeof(vcat), Δ, xs...)
end
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
track(reshape, xs, dims...)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
track(reshape, xs, dims)
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))

View File

@ -82,6 +82,21 @@ end
@test param(2)^2 == 4.0
@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
@test x isa TrackedArray
@test size(x) == (4,2)
x = reshape(param([1]), (1,:))
@test x isa TrackedArray
@test size(x) == (1,1)
x = reshape(param(rand(2)), (2,:))
@test x isa TrackedArray
@test size(x) == (2,1)
x = reshape(param(rand(2,2)), (1,:,2))
@test x isa TrackedArray
@test size(x) == (1,2,2)
end
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)