Merge branch 'cl-docs' of https://github.com/FluxML/Flux.jl into cl-docs
This commit is contained in:
commit
ba92f9a140
|
@ -38,6 +38,40 @@ m = fmap(cu, m)
|
||||||
d(cu(rand(10)))
|
d(cu(rand(10)))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
However, if you create a customized model, `fmap` may not work out of the box.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
julia> struct ActorCritic{A, C}
|
||||||
|
actor::A
|
||||||
|
critic::C
|
||||||
|
end
|
||||||
|
|
||||||
|
julia> m = ActorCritic(ones(2,2), ones(2))
|
||||||
|
ActorCritic{Array{Float64,2},Array{Float64,1}}([1.0 1.0; 1.0 1.0], [1.0, 1.0])
|
||||||
|
|
||||||
|
julia> fmap(cu, m)
|
||||||
|
ActorCritic{Array{Float64,2},Array{Float64,1}}([1.0 1.0; 1.0 1.0], [1.0, 1.0])
|
||||||
|
```
|
||||||
|
|
||||||
|
As you can see, nothing changed after `fmap(cu, m)`. The reason is that `Flux` doesn't know your customized model structure. To make it work as expected, you need the `@functor` macro.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
julia> Flux.@functor ActorCritic
|
||||||
|
|
||||||
|
julia> fmap(cu, m)
|
||||||
|
ActorCritic{CuArray{Float32,2,Nothing},CuArray{Float32,1,Nothing}}(Float32[1.0 1.0; 1.0 1.0], Float32[1.0, 1.0])
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can see that the inner fields of `actor` and `critic` are transformed into `CuArray`. So what does the `@functor` macro do here? Basically, it will create a function like this:
|
||||||
|
|
||||||
|
```julia
|
||||||
|
Flux.functor(m::ActorCritic) = (actor = m.actor, critic=m.critic), fields -> ActorCritic(fields...)
|
||||||
|
```
|
||||||
|
|
||||||
|
And the `functor` will be called recursively in `fmap`. As you can see, the result of `functor` contains two parts, a *destructure* part and a *reconstrucutre* part. The first part is to make the customized model structure into `trainable` data structure known to `Flux` (here is a `NamedTuple`). The goal is to turn `m` into `(actor=cu(ones(2,2)), critic=cu(ones(2)))`. The second part is to turn the result back into a `ActorCritic`, so that we can get `ActorCritic(cu(ones(2,2)),cu(ones(2)))`.
|
||||||
|
|
||||||
|
By default, the `@functor` macro will transform all the fields in your customized structure. In some cases, you may only want to transform several fields. Then you just specify those fields manually like `Flux.@functor ActorCritic (actor,)` (note that the fields part must be a tuple). And make sure the `ActorCritic(actor)` constructor is also implemented.
|
||||||
|
|
||||||
As a convenience, Flux provides the `gpu` function to convert models and data to the GPU if one is available. By default, it'll do nothing, but loading `CuArrays` will cause it to move data to the GPU instead.
|
As a convenience, Flux provides the `gpu` function to convert models and data to the GPU if one is available. By default, it'll do nothing, but loading `CuArrays` will cause it to move data to the GPU instead.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
|
@ -73,4 +107,4 @@ julia> x |> cpu
|
||||||
0.235164
|
0.235164
|
||||||
⋮
|
⋮
|
||||||
0.192538
|
0.192538
|
||||||
```
|
```
|
Loading…
Reference in New Issue