support random numbers as constants
This commit is contained in:
parent
9781f063aa
commit
cf061e9207
@ -53,6 +53,12 @@ Base.float(x::TrackedReal) = x
|
|||||||
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
|
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
|
||||||
TrackedReal{promote_type(S,T)}
|
TrackedReal{promote_type(S,T)}
|
||||||
|
|
||||||
|
using Random
|
||||||
|
|
||||||
|
Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T))
|
||||||
|
Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T))
|
||||||
|
Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T))
|
||||||
|
|
||||||
using DiffRules, SpecialFunctions, NaNMath
|
using DiffRules, SpecialFunctions, NaNMath
|
||||||
|
|
||||||
for (M, f, arity) in DiffRules.diffrules()
|
for (M, f, arity) in DiffRules.diffrules()
|
||||||
|
Loading…
Reference in New Issue
Block a user