some docstrings
commit 1c21a860e2 (parent 62fd13bded)
src/model.jl (60 lines changed)

@@ -2,29 +2,87 @@ export Model, back!, update!, param
# Basic model API
"""
|
||||||
|
(m::Model)(X...) => Y
|
||||||
|
|
||||||
|
A "model" is a function with state. For example, a logistic regression is the
|
||||||
|
function
|
||||||
|
|
||||||
|
x -> σ(x * W + b)
|
||||||
|
|
||||||
|
where `W` and `b` are a trainable matrix and vector of weights repectively. The
|
||||||
|
`Model` abstract type is used loosely; in general the concept of a model is
|
||||||
|
closer to a protocol, and models don't need to inherit from this type. Normal
|
||||||
|
Julia functions are models with 0 parameters, for example.
|
||||||
|
"""
|
||||||
abstract Model
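# Editor's sketch (not part of the commit): the logistic regression above as a
# stateful, callable model. It mirrors the docstring's `x -> σ(x * W + b)` and
# assumes a sigmoid `σ` is in scope.
type LogisticRegression <: Model
  W  # weight matrix
  b  # bias vector
end

(m::LogisticRegression)(x) = σ(x * m.W + m.b)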
"""
    back!(m::Model, ΔY, X...) => ΔX

Backpropagate the gradient `ΔY` through the model `m`, accumulating the
gradients of any parameters. Returns the gradient of the input `X`. Gradients
may be arrays or tuples of arrays (for multiple inputs/outputs).
"""
back!(m::Model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
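# Editor's sketch (not part of the commit): a tiny model that implements
# `back!`. `Scale` multiplies its input elementwise by a parameter `s`
# (a `Param`, defined further down in this file).
type Scale <: Model
  s::Param
end

(m::Scale)(x) = x .* m.s.x

function back!(m::Scale, Δ, x)
  accumulate!(m.s, Δ .* x)  # gradient w.r.t. the parameter, accumulated on it
  return Δ .* m.s.x         # gradient w.r.t. the input, returned to the caller
end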
"""
|
||||||
|
update!(m::Model, η) => m
|
||||||
|
|
||||||
|
Update the parameters of the model `m` using the accumulated gradients from
|
||||||
|
`back!`, using the learning rate `η`.
|
||||||
|
"""
|
||||||
update!(m, η) = m
|
update!(m, η) = m
|
||||||
|
|
||||||
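# Editor's sketch: `update!` for the `Scale` example simply forwards to the
# parameter's own `update!` (defined at the end of this file) and returns the
# model, as the docstring above requires.
function update!(m::Scale, η)
  update!(m.s, η)
  return m
end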
"""
    graph(m::Model) => ::IVertex{Any} | nothing

Returns the graph representation of the model, if any. Most models are built
from lower-level components and can simply implement this method to get most of
Flux's functionality. If this method isn't available, functionality like
backpropagation or conversion for backends must be implemented on a case-by-case
basis. Alternatively, one can implement this method and override individual
methods as necessary.
"""
graph(m) = nothing
# Model parameters
"""
A `Param` object stores a parameter array along with an accumulated delta to
that array. When converting to backends like TensorFlow, identical `Param`s will
result in identical variable objects, making model reuse trivial.
"""
type Param{T}
  x::T
  Δx::T
end
"""
|
||||||
|
param(x::T) => ::Param{T}
|
||||||
|
|
||||||
|
Convenience method for creating a `Param` object for a given array.
|
||||||
|
"""
|
||||||
param(x) = Param(x, zero(x))
|
param(x) = Param(x, zero(x))
|
||||||
|
|
||||||
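# Editor's sketch: `param` pairs an array with a zeroed delta of the same shape.
W = param(randn(10, 5))
W.x   # the current values
W.Δx  # the accumulated delta, initially all zeros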
state(p::Param) = p.x
"""
|
||||||
|
accumulate!(p::Param, Δ) => p
|
||||||
|
|
||||||
|
Accumulates the update `Δ` on `p`. The value of `p` won't change until
|
||||||
|
`update!`.
|
||||||
|
"""
|
||||||
function accumulate!(p::Param, Δ)
|
function accumulate!(p::Param, Δ)
|
||||||
p.Δx .+= Δ
|
p.Δx .+= Δ
|
||||||
return p
|
return p
|
||||||
end
|
end
|
||||||
|
|
||||||
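# Editor's sketch: accumulation only touches the delta; the value waits for
# `update!`.
p = param(zeros(3))
accumulate!(p, [1.0, 2.0, 3.0])  # p.Δx == [1.0, 2.0, 3.0]; p.x is still zeros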
"""
    update!(p::Param, η)

Apply the accumulated updates to the value of the parameter.
"""
function update!(p::Param, η)
  p.x .-= p.Δx .* η
  p.Δx[:] = 0
end
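# Editor's sketch: applying the delta accumulated in the sketch above with a
# learning rate, which also resets the delta to zero.
update!(p, 0.1)  # p.x -= 0.1 .* p.Δx; p.Δx is zeroed afterwards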