diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index 4abadfa3..bb2d8581 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -72,7 +72,7 @@ for (M, f, arity) in DiffRules.diffrules() f = :($M.$f) @eval begin @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * convert(eltype(Δ), $da), _zero(b)) + @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * convert(TrackedReal{eltype(Δ)}, $da), _zero(b)) @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b)