docs for functor.jl
This commit is contained in:
parent
af99ca27ee
commit
94d95442ab
|
@ -37,7 +37,7 @@ include("layers/normalise.jl")
|
|||
|
||||
include("data/Data.jl")
|
||||
|
||||
include("deprecations.jl")
|
||||
include("deprecated.jl")
|
||||
|
||||
function __init__()
|
||||
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
|
||||
|
|
|
@ -1,2 +1,5 @@
|
|||
import Base: @deprecate
|
||||
|
||||
|
||||
@deprecate param(x) x
|
||||
@deprecate data(x) x
|
|
@ -1,6 +1,15 @@
|
|||
import Adapt: adapt, adapt_storage
|
||||
using Zygote: IdSet
|
||||
|
||||
"""
|
||||
functor(x) -> func, re
|
||||
|
||||
We have `x == re(func)`.
|
||||
Return `func = ()` and `re = _ -> x` for leaf objects.
|
||||
"""
|
||||
function functor end
|
||||
|
||||
# by default, every object is a leaf
|
||||
functor(x) = (), _ -> x
|
||||
|
||||
functor(x::Tuple) = x, y -> y
|
||||
|
@ -21,10 +30,35 @@ function functorm(T, fs = nothing)
|
|||
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
|
||||
end
|
||||
|
||||
"""
|
||||
@functor T fields...
|
||||
|
||||
Given a type `T` and a subset of its fieldnames `fields`,
|
||||
create a [`functor`](@ref) function :
|
||||
|
||||
functor(x::T) -> func, re
|
||||
|
||||
where
|
||||
|
||||
func: (field1 = x.field1, field2 = x.field2, ....)
|
||||
|
||||
re: y -> T(y...)
|
||||
|
||||
If no `fields` argument is given, all internal fields will be considered.
|
||||
"""
|
||||
macro functor(args...)
|
||||
functorm(args...)
|
||||
end
|
||||
|
||||
"""
|
||||
isleaf(x)
|
||||
|
||||
Check if variable `x` is a *leaf* according to the definition:
|
||||
|
||||
isleaf(x) = functor(x)[1] === ()
|
||||
|
||||
See [`functor`](@ref).
|
||||
"""
|
||||
isleaf(x) = functor(x)[1] === ()
|
||||
|
||||
function fmap1(f, x)
|
||||
|
@ -32,6 +66,17 @@ function fmap1(f, x)
|
|||
re(map(f, func))
|
||||
end
|
||||
|
||||
"""
|
||||
fmap(f, m)
|
||||
|
||||
Applies function `f` to each leaf (see [`isleaf`](@ref)) in `m` and reconstructs
|
||||
`m` from the transformed leaves.
|
||||
|
||||
Example:
|
||||
|
||||
gpu(m) = fmap(CuArrays.cu, m)
|
||||
|
||||
"""
|
||||
function fmap(f, x; cache = IdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
|
||||
|
@ -81,9 +126,40 @@ function params!(p::Params, x, seen = IdSet())
|
|||
end
|
||||
end
|
||||
|
||||
function params(m...)
|
||||
"""
|
||||
params(x...)
|
||||
|
||||
Recursively scans the inputs for trainable params
|
||||
and collects them into a `Zygote.Params` object `ps`.
|
||||
|
||||
***Usage***
|
||||
|
||||
W = rand(5, 3)
|
||||
b = zeros(5)
|
||||
m = Dense(W, b)
|
||||
|
||||
ps = params(W, b)
|
||||
ps = params([W, b]) # equivalent form
|
||||
ps = params(m) # equivalent form
|
||||
|
||||
x = rand(3)
|
||||
y = rand(5)
|
||||
loss(W, b) = sum(((W*x + b) - y).^2)
|
||||
loss(m) = sum((m(x) - y).^2)
|
||||
|
||||
# Gradient computation.
|
||||
# Returns a tuple of 2 of arrays containing the gradients.
|
||||
gs = gradient((W, b) -> loss(W, b), W, b)
|
||||
|
||||
# Gradient behaves differently with Params.
|
||||
# ps is not fed as an argument to the loss.
|
||||
# Returns a Zygote.Grads object.
|
||||
gs = gradient(() -> loss(m), ps)
|
||||
|
||||
"""
|
||||
function params(x...)
|
||||
ps = Params()
|
||||
params!(ps, m)
|
||||
params!(ps, x)
|
||||
return ps
|
||||
end
|
||||
|
||||
|
@ -91,6 +167,8 @@ end
|
|||
macro treelike(args...)
|
||||
functorm(args...)
|
||||
end
|
||||
|
||||
|
||||
mapleaves(f, x) = fmap(f, x)
|
||||
|
||||
function loadparams!(m, xs)
|
||||
|
@ -102,10 +180,21 @@ function loadparams!(m, xs)
|
|||
end
|
||||
|
||||
# CPU/GPU movement conveniences
|
||||
"""
|
||||
cpu(m)
|
||||
|
||||
Move model or data `m` to the cpu. Makes
|
||||
copies only if needed.
|
||||
"""
|
||||
cpu(m) = fmap(x -> adapt(Array, x), m)
|
||||
|
||||
gpu(x) = use_cuda[] ? fmap(CuArrays.cu, x) : x
|
||||
"""
|
||||
gpu(m)
|
||||
|
||||
Move model or data `m` to the gpu device if available,
|
||||
otherwise do nothing. Makes copies only if needed.
|
||||
"""
|
||||
gpu(m) = use_cuda[] ? fmap(CuArrays.cu, m) : m
|
||||
|
||||
# Precision
|
||||
|
||||
|
|
Loading…
Reference in New Issue