diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 57474933..71d93e88 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys) @back(ys, Δ[size(xs,1)+1:end, i...]) end +Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = + TrackedArray(Call(reshape, xs, dims...)) + +back(::typeof(reshape), Δ, xs::TrackedArray, _...) = + back(xs, reshape(Δ, size(xs))) + # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))