Compare commits

5 Commits

| Author | SHA1 | Date |
|---|---|---|
| | ba92f9a140 | |
| | 4516978caa | |
| | 19df897de7 | |
| | 94d95442ab | |
| | 64b4a6a80c | |
@@ -38,6 +38,40 @@ m = fmap(cu, m)
d(cu(rand(10)))
```

However, if you create a customized model, `fmap` may not work out of the box.

```julia
julia> struct ActorCritic{A, C}
           actor::A
           critic::C
       end

julia> m = ActorCritic(ones(2,2), ones(2))
ActorCritic{Array{Float64,2},Array{Float64,1}}([1.0 1.0; 1.0 1.0], [1.0, 1.0])

julia> fmap(cu, m)
ActorCritic{Array{Float64,2},Array{Float64,1}}([1.0 1.0; 1.0 1.0], [1.0, 1.0])
```

As you can see, nothing changed after `fmap(cu, m)`. The reason is that `Flux` doesn't know the structure of your customized model, so `fmap` treats the whole model as a single leaf, and `cu` leaves a struct it doesn't recognize unchanged. To make it work as expected, you need the `@functor` macro.

```julia
julia> Flux.@functor ActorCritic

julia> fmap(cu, m)
ActorCritic{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}(Float32[1.0 1.0; 1.0 1.0], Float32[1.0, 1.0])
```

Now you can see that the `actor` and `critic` fields have been transformed into `CuArray`s. So what does the `@functor` macro do here? Basically, it defines a method like this:

```julia
Flux.functor(m::ActorCritic) = (actor = m.actor, critic = m.critic), fields -> ActorCritic(fields...)
```

`functor` is called recursively inside `fmap`. As you can see, its result has two parts: a *destructuring* part and a *reconstructing* part. The first part turns the customized model into a data structure that `Flux` knows how to traverse, here a `NamedTuple`; the goal is to turn `m` into `(actor = cu(ones(2,2)), critic = cu(ones(2)))`. The second part turns the transformed fields back into an `ActorCritic`, so that we end up with `ActorCritic(cu(ones(2,2)), cu(ones(2)))`.
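
To make the two parts concrete, the sketch below spells out what `fmap` effectively does one level down for the `ActorCritic` example (illustrative only; the variable names `moved` and `m_gpu` are not from the original):

```julia
func, re = Flux.functor(m)   # (actor = m.actor, critic = m.critic) plus a rebuilder
moved = map(cu, func)        # apply `cu` to every field of the NamedTuple
m_gpu = re(moved)            # rebuild an ActorCritic from the transformed fields
```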

By default, the `@functor` macro transforms all the fields of your customized structure. In some cases you may want to transform only a few of them; then you can list those fields explicitly, as in `Flux.@functor ActorCritic (actor,)` (note that the fields part must be a tuple), and make sure the corresponding `ActorCritic(actor)` constructor is also implemented, as sketched below.
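
For example, a minimal sketch of this partial transformation might look as follows (the one-argument constructor is hypothetical; it just supplies a default `critic` so the generated reconstructor can rebuild the struct):

```julia
# Only the `actor` field is exposed to `fmap`.
Flux.@functor ActorCritic (actor,)

# The generated reconstructor calls `ActorCritic(actor)`, so a one-argument
# constructor must exist; here it pairs the actor with a default critic.
ActorCritic(actor) = ActorCritic(actor, ones(2))

fmap(cu, m)   # now only `actor` ends up as a `CuArray`
```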

As a convenience, Flux provides the `gpu` function to convert models and data to the GPU if one is available. By default, it'll do nothing, but loading `CuArrays` will cause it to move data to the GPU instead.

```julia
@@ -73,4 +107,4 @@ julia> x |> cpu
0.235164
⋮
0.192538
```
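
A typical round trip with these convenience functions might look like the following sketch (assuming `CuArrays` is installed and a GPU is available):

```julia
using Flux, CuArrays

m = Dense(10, 5) |> gpu        # move the layer's parameters to the GPU
x = rand(Float32, 10) |> gpu   # move the input data as well
y = m(x)                       # runs on the GPU and returns a CuArray

y |> cpu                       # copy the result back into a regular Array
```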
@@ -37,7 +37,7 @@ include("layers/normalise.jl")
 
 include("data/Data.jl")
 
-include("deprecations.jl")
+include("deprecated.jl")
 
 function __init__()
   precompiling = ccall(:jl_generating_output, Cint, ()) != 0
@@ -0,0 +1,14 @@
+import Base: @deprecate
+
+#### remove in v 0.11 #####
+@deprecate param(x) x
+@deprecate data(x) x
+
+@deprecate mapleaves(f, x) fmap(f, x)
+
+macro treelike(args...)
+  functorm(args...)
+end
+#############################
+
+
@@ -1,2 +0,0 @@
-@deprecate param(x) x
-@deprecate data(x) x
@@ -1,6 +1,15 @@
 import Adapt: adapt, adapt_storage
 using Zygote: IdSet
 
+"""
+    functor(x) -> func, re
+
+We have `x == re(func)`.
+Return `func = ()` and `re = _ -> x` for leaf objects.
+"""
+function functor end
+
+# by default, every object is a leaf
 functor(x) = (), _ -> x
 
 functor(x::Tuple) = x, y -> y
@@ -21,10 +30,35 @@ function functorm(T, fs = nothing)
   :(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
 end
 
+"""
+    @functor T fields...
+
+Given a type `T` and a subset of its fieldnames `fields`,
+create a [`functor`](@ref) function :
+
+    functor(x::T) -> func, re
+
+where
+
+    func: (field1 = x.field1, field2 = x.field2, ....)
+
+    re: y -> T(y...)
+
+If no `fields` argument is given, all internal fields will be considered.
+"""
 macro functor(args...)
   functorm(args...)
 end
 
+"""
+    isleaf(x)
+
+Check if variable `x` is a *leaf* according to the definition:
+
+    isleaf(x) = functor(x)[1] === ()
+
+See [`functor`](@ref).
+"""
 isleaf(x) = functor(x)[1] === ()
 
 function fmap1(f, x)
@@ -32,6 +66,17 @@ function fmap1(f, x)
   re(map(f, func))
 end
 
+"""
+    fmap(f, m)
+
+Applies function `f` to each leaf (see [`isleaf`](@ref)) in `m` and reconstructs
+`m` from the transformed leaves.
+
+Example:
+
+    gpu(m) = fmap(CuArrays.cu, m)
+
+"""
 function fmap(f, x; cache = IdDict())
   haskey(cache, x) && return cache[x]
   cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
@@ -81,18 +126,43 @@ function params!(p::Params, x, seen = IdSet())
   end
 end
 
-function params(m...)
+"""
+    params(x...)
+
+Recursively scans the inputs for trainable params
+and collects them into a `Zygote.Params` object `ps`.
+
+***Usage***
+
+    W = rand(5, 3)
+    b = zeros(5)
+    m = Dense(W, b)
+
+    ps = params(W, b)
+    ps = params([W, b]) # equivalent form
+    ps = params(m)      # equivalent form
+
+    x = rand(3)
+    y = rand(5)
+    loss(W, b) = sum(((W*x + b) - y).^2)
+    loss(m) = sum((m(x) - y).^2)
+
+    # Gradient computation.
+    # Returns a tuple of 2 arrays containing the gradients.
+    gs = gradient((W, b) -> loss(W, b), W, b)
+
+    # Gradient behaves differently with Params.
+    # ps is not fed as an argument to the loss.
+    # Returns a Zygote.Grads object.
+    gs = gradient(() -> loss(m), ps)
+
+"""
+function params(x...)
   ps = Params()
-  params!(ps, m)
+  params!(ps, x)
   return ps
 end
 
-# Deprecated stuff
-macro treelike(args...)
-  functorm(args...)
-end
-mapleaves(f, x) = fmap(f, x)
-
 function loadparams!(m, xs)
   for (p, x) in zip(params(m), xs)
     size(p) == size(x) ||
@@ -102,10 +172,21 @@ function loadparams!(m, xs)
 end
 
 # CPU/GPU movement conveniences
+"""
+    cpu(m)
+
+Move model or data `m` to the cpu. Makes
+copies only if needed.
+"""
 cpu(m) = fmap(x -> adapt(Array, x), m)
 
-gpu(x) = use_cuda[] ? fmap(CuArrays.cu, x) : x
+"""
+    gpu(m)
+
+Move model or data `m` to the gpu device if available,
+otherwise do nothing. Makes copies only if needed.
+"""
+gpu(m) = use_cuda[] ? fmap(CuArrays.cu, m) : m
 
 # Precision