From d25e05d9eed5cde043a609bf6aca63bc545ee6b5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 27 Sep 2018 10:40:44 +0200 Subject: [PATCH] evaluate both 2-ary DiffRules only when needed --- src/tracker/scalar.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 81ccb9a3..1b6098fb 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -63,7 +63,9 @@ for (M, f, arity) in DiffRules.diffrules() da, db = DiffRules.diffrule(M, f, :a, :b) f = :($M.$f) @eval begin - @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, zero(b)) + @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (zero(a), Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::Real, b::TrackedReal) = track($f, a, b)