define back function right after forward function
This commit is contained in:
parent
509a2e59f6
commit
fb68529169
@ -81,19 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
|||||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||||
|
|
||||||
for f in [:vcat, :hcat]
|
|
||||||
@eval begin
|
|
||||||
Base.$f(a::TrackedArray...) = track($f, a...)
|
|
||||||
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)
|
|
||||||
|
|
||||||
# assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
|
|
||||||
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...)
|
|
||||||
Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...)
|
|
||||||
|
|
||||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||||
Δ′ = similar(xs.data)
|
Δ′ = similar(xs.data)
|
||||||
S = size(xs.data)
|
S = size(xs.data)
|
||||||
@ -106,6 +93,16 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
|||||||
back(xs, Δ′)
|
back(xs, Δ′)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
for f in [:vcat, :hcat]
|
||||||
|
@eval begin
|
||||||
|
Base.$f(a::TrackedArray...) = track($f, a...)
|
||||||
|
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)
|
||||||
|
|
||||||
|
# assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
|
||||||
|
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function back(::typeof(vcat), Δ, xs...)
|
function back(::typeof(vcat), Δ, xs...)
|
||||||
start = 0
|
start = 0
|
||||||
for xsi in xs
|
for xsi in xs
|
||||||
@ -128,6 +125,9 @@ function back(::typeof(hcat), Δ, xs...)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...)
|
||||||
|
Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...)
|
||||||
|
|
||||||
function back(::typeof(cat), Δ, dim, xs...)
|
function back(::typeof(cat), Δ, dim, xs...)
|
||||||
start = 0
|
start = 0
|
||||||
for xsi in xs
|
for xsi in xs
|
||||||
|
Loading…
Reference in New Issue
Block a user