add some non-differentiable functions
This commit is contained in:
parent
387686eb41
commit
107d9daa8f
@ -27,6 +27,9 @@ Base.sum(xs::TrackedScalar, dim...) = xs
|
||||
|
||||
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
|
||||
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.x, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.x, args...)
|
||||
|
||||
# BLAS
|
||||
|
||||
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||
|
Loading…
Reference in New Issue
Block a user