add some non-differentiable functions
This commit is contained in:
parent
387686eb41
commit
107d9daa8f
@ -27,6 +27,9 @@ Base.sum(xs::TrackedScalar, dim...) = xs
|
|||||||
|
|
||||||
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
|
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
|
||||||
|
|
||||||
|
Base.maximum(xs::TrackedArray, args...) = maximum(xs.x, args...)
|
||||||
|
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.x, args...)
|
||||||
|
|
||||||
# BLAS
|
# BLAS
|
||||||
|
|
||||||
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||||
|
Loading…
Reference in New Issue
Block a user