From cf061e9207918eff0793d93d1eca78fe878c2071 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Jan 2019 23:04:12 -0800 Subject: [PATCH] support random numbers as constants --- src/tracker/lib/real.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index 146706c7..c5cdfa69 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -53,6 +53,12 @@ Base.float(x::TrackedReal) = x Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = TrackedReal{promote_type(S,T)} +using Random + +Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T)) +Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T)) +Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T)) + using DiffRules, SpecialFunctions, NaNMath for (M, f, arity) in DiffRules.diffrules()