mapparams

This commit is contained in:
Mike J Innes 2019-01-25 10:11:46 +00:00
parent 791939709b
commit 2b1a3e92da
1 changed files with 10 additions and 0 deletions

View File

@ -73,3 +73,13 @@ paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
# General parameter map
function mapparams(f, m)
mapleaves(m) do x
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
x isa Union{AbstractArray,Number} ? f(x) :
x
end
end