From fb685291693cf0c9dbc466557af6d3ab7f078e88 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 08:37:30 +0200 Subject: [PATCH] define back function right after forward function --- src/tracker/array.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 89fce39e..71a2d530 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,19 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -for f in [:vcat, :hcat] - @eval begin - Base.$f(a::TrackedArray...) = track($f, a...) - Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) - - # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector - Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) - end -end - -Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) -Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) - function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -106,6 +93,16 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) back(xs, Δ′) end +for f in [:vcat, :hcat] + @eval begin + Base.$f(a::TrackedArray...) = track($f, a...) + Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) + + # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector + Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) + end +end + function back(::typeof(vcat), Δ, xs...) start = 0 for xsi in xs @@ -128,6 +125,9 @@ function back(::typeof(hcat), Δ, xs...) end end +Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) +Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) + function back(::typeof(cat), Δ, dim, xs...) start = 0 for xsi in xs