mapparams
This commit is contained in:
parent
791939709b
commit
2b1a3e92da
@ -73,3 +73,13 @@ paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
|
|||||||
|
|
||||||
f32(m) = paramtype(Float32, m)
|
f32(m) = paramtype(Float32, m)
|
||||||
f64(m) = paramtype(Float64, m)
|
f64(m) = paramtype(Float64, m)
|
||||||
|
|
||||||
|
# General parameter map
|
||||||
|
|
||||||
|
function mapparams(f, m)
|
||||||
|
mapleaves(m) do x
|
||||||
|
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
|
||||||
|
x isa Union{AbstractArray,Number} ? f(x) :
|
||||||
|
x
|
||||||
|
end
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user