From cf061e9207918eff0793d93d1eca78fe878c2071 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Jan 2019 23:04:12 -0800 Subject: [PATCH 1/4] support random numbers as constants --- src/tracker/lib/real.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index 146706c7..c5cdfa69 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -53,6 +53,12 @@ Base.float(x::TrackedReal) = x Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = TrackedReal{promote_type(S,T)} +using Random + +Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T)) +Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T)) +Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T)) + using DiffRules, SpecialFunctions, NaNMath for (M, f, arity) in DiffRules.diffrules() From 3ee5a9979470746858ab460fd415d3481a2bcb80 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Jan 2019 23:15:21 -0800 Subject: [PATCH 2/4] hit all possibilities --- src/tracker/lib/real.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index c5cdfa69..5c0ba209 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -55,8 +55,19 @@ Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = using Random +Random.rand(x::Flux.Tracker.TrackedReal} = rand(typeof(x)) +Random.rand(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(T)) +Random.rand(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = rand(rng,typeof(x)) Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T)) + +Random.randn(x::Flux.Tracker.TrackedReal} = randn(typeof(x)) +Random.randn(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(T)) +Random.randn(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = randn(rng,typeof(x)) Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T)) + +Random.randexp(x::Flux.Tracker.TrackedReal} = randexp(typeof(x)) +Random.randexp(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(T)) +Random.randexp(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = randexp(rng,typeof(x)) Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T)) using DiffRules, SpecialFunctions, NaNMath From f6faa10ee24581bbf087865a031c02e1c90331a3 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 10 Jan 2019 08:57:10 -0800 Subject: [PATCH 3/4] remove non-type dispatches --- src/tracker/lib/real.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index 5c0ba209..b8285433 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -55,19 +55,13 @@ Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = using Random -Random.rand(x::Flux.Tracker.TrackedReal} = rand(typeof(x)) Random.rand(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(T)) -Random.rand(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = rand(rng,typeof(x)) Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T)) -Random.randn(x::Flux.Tracker.TrackedReal} = randn(typeof(x)) Random.randn(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(T)) -Random.randn(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = randn(rng,typeof(x)) Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T)) -Random.randexp(x::Flux.Tracker.TrackedReal} = randexp(typeof(x)) Random.randexp(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(T)) -Random.randexp(rng::AbstractRNG,x::Flux.Tracker.TrackedReal} = randexp(rng,typeof(x)) Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T)) using DiffRules, SpecialFunctions, NaNMath From aa1b4f410f66c6be7fb518303c2d73b3bc9b97a6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 11 Jan 2019 10:06:14 +0000 Subject: [PATCH 4/4] simplify --- src/tracker/lib/real.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index b8285433..b2584cbe 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -55,14 +55,9 @@ Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = using Random -Random.rand(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(T)) -Random.rand(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},rand(rng,T)) - -Random.randn(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(T)) -Random.randn(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randn(rng,T)) - -Random.randexp(::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(T)) -Random.randexp(rng::AbstractRNG,::Type{Flux.Tracker.TrackedReal{T}}) where {T} = convert(Flux.Tracker.TrackedReal{T},randexp(rng,T)) +for f in :[rand, randn, randexp].args + @eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T)) +end using DiffRules, SpecialFunctions, NaNMath