mapparams
This commit is contained in:
parent
791939709b
commit
2b1a3e92da
|
@ -73,3 +73,13 @@ paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
|
|||
|
||||
f32(m) = paramtype(Float32, m)
|
||||
f64(m) = paramtype(Float64, m)
|
||||
|
||||
# General parameter map
|
||||
|
||||
function mapparams(f, m)
|
||||
mapleaves(m) do x
|
||||
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
|
||||
x isa Union{AbstractArray,Number} ? f(x) :
|
||||
x
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue