Omega and Turing fix
This commit is contained in:
parent
cb773e54c0
commit
96dbae2d20
@ -60,14 +60,18 @@ for (M, f, arity) in DiffRules.diffrules()
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Work around zero(π) not working, for some reason
|
||||||
|
_zero(::Irrational) = nothing
|
||||||
|
_zero(x) = zero(x)
|
||||||
|
|
||||||
for (M, f, arity) in DiffRules.diffrules()
|
for (M, f, arity) in DiffRules.diffrules()
|
||||||
arity == 2 || continue
|
arity == 2 || continue
|
||||||
da, db = DiffRules.diffrule(M, f, :a, :b)
|
da, db = DiffRules.diffrule(M, f, :a, :b)
|
||||||
f = :($M.$f)
|
f = :($M.$f)
|
||||||
@eval begin
|
@eval begin
|
||||||
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
||||||
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, zero(b))
|
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
|
||||||
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (zero(a), Δ * $db)
|
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
|
||||||
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
||||||
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
||||||
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
||||||
|
Loading…
Reference in New Issue
Block a user