From aa1b4f410f66c6be7fb518303c2d73b3bc9b97a6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 11 Jan 2019 10:06:14 +0000 Subject: [PATCH] simplify --- src/tracker/lib/real.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index b8285433..b2584cbe 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -55,14 +55,9 @@ Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = using Random -Random.rand(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(T)) -Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T)) - -Random.randn(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(T)) -Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T)) - -Random.randexp(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(T)) -Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T)) +for f in :[rand, randn, randexp].args + @eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T)) +end using DiffRules, SpecialFunctions, NaNMath