This commit is contained in:
Mike J Innes 2019-01-11 10:06:14 +00:00
parent f6faa10ee2
commit aa1b4f410f

View File

@ -55,14 +55,9 @@ Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
using Random
Random.rand(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(T))
Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T))
Random.randn(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(T))
Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T))
Random.randexp(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(T))
Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T))
for f in :[rand, randn, randexp].args
@eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T))
end
using DiffRules, SpecialFunctions, NaNMath