diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index c5cdfa69..5c0ba209 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -55,8 +55,19 @@ Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = using Random +Random.rand(x::Flux.Tracker.TrackedReal} = rand(typeof(x)) +Random.rand(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(T)) +Random.rand(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = rand(rng,typeof(x)) Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T)) + +Random.randn(x::Flux.Tracker.TrackedReal} = randn(typeof(x)) +Random.randn(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(T)) +Random.randn(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = randn(rng,typeof(x)) Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T)) + +Random.randexp(x::Flux.Tracker.TrackedReal} = randexp(typeof(x)) +Random.randexp(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(T)) +Random.randexp(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = randexp(rng,typeof(x)) Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T)) using DiffRules, SpecialFunctions, NaNMath