fix reshape
This commit is contained in:
parent
4320738d87
commit
b415333233
|
@ -117,11 +117,9 @@ function back(::typeof(vcat), Δ, xs...)
|
|||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
||||
track(reshape, xs, dims...)
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
|
||||
track(reshape, xs, dims)
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||
|
||||
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||
back(xs, reshape(Δ, size(xs)))
|
||||
|
|
|
@ -82,6 +82,21 @@ end
|
|||
|
||||
@test param(2)^2 == 4.0
|
||||
|
||||
@testset "reshape" begin
|
||||
x = reshape(param(rand(2,2,2)), 4, 2)
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (4,2)
|
||||
x = reshape(param([1]), (1,:))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (1,1)
|
||||
x = reshape(param(rand(2)), (2,:))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (2,1)
|
||||
x = reshape(param(rand(2,2)), (1,:,2))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (1,2,2)
|
||||
end
|
||||
|
||||
@testset "Intermediates" begin
|
||||
x = param([1])
|
||||
l = sum((x .+ x).^2)
|
||||
|
|
Loading…
Reference in New Issue