# Flux.jl/src/model.jl

"""
    back!(m, ΔY, X...) => ΔX

Backpropagate the gradient `ΔY` through the model `m`, accumulating the
gradients of any parameters. Returns the gradient of the input `X`. Gradients
may be arrays or tuples of arrays (for multiple inputs/outputs).
"""
back!(m, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
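
# A minimal sketch of what an implementation might look like: a hypothetical
# dense layer `y = W*x .+ b` (not part of this file) whose fields are `Param`s
# as defined below. For a column-vector input `x`:
#
#     struct Dense
#       W::Param
#       b::Param
#     end
#
#     (m::Dense)(x) = m.W.x * x .+ m.b.x
#
#     function back!(m::Dense, ΔY, x)
#       m.W.Δx .+= ΔY * x'   # accumulate ∂loss/∂W
#       m.b.Δx .+= ΔY        # accumulate ∂loss/∂b
#       return m.W.x' * ΔY   # ΔX, the gradient w.r.t. the input
#     end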

"""
    update!(m, η) => m

Update the parameters of the model `m`, using the accumulated gradients from
`back!` and the learning rate `η`.
"""
update!(m, η) = m
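
# Together, `back!` and `update!` support a plain gradient-descent step. A
# hedged sketch, assuming a model `m` implementing both methods and a loss
# gradient `Δ` computed elsewhere:
#
#     ŷ = m(x)           # forward pass
#     back!(m, Δ, x)     # accumulate parameter gradients
#     update!(m, 0.01)   # descend with learning rate η = 0.01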

"""
    graph(m) => ::IVertex{Any} | nothing

Returns the graph representation of the model, if any. Most models are built
from lower-level components and can simply implement this method to get most of
Flux's functionality. If this method isn't available, functionality like
backpropagation or backend conversion must be implemented on a case-by-case
basis. Alternatively, one can implement this method and override individual
methods as necessary.
"""
graph(m) = nothing
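
# `Capacitor` at the end of this file is the simplest example: it stores a
# DataFlow graph directly, so its `graph` method is a one-line accessor.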

# Model parameters

# TODO: should be AbstractArray?
"""
A `Param` object stores a parameter array along with its gradient.
When converting to backends like TensorFlow, identical `Param`s will
result in identical variable objects.
"""
struct Param{T}
  x::T
  Δx::T
end

"""
    param(x::T) => ::Param{T}

Convenience method for creating a `Param` object for a given array.
"""
param(x) = Param(x, zero(x))
state(p::Param) = p.x
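
# For example, wrapping a weight matrix (the gradient storage starts at zero,
# since `zero(x)` returns an all-zero array of the same shape):
#
#     W = param(randn(5, 10))
#     state(W) == W.x   # true; `state` unwraps the underlying array
#     size(W.Δx)        # (5, 10), the accumulated gradient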

"""
    update!(p::Param, η)

Apply the accumulated updates to the value of the parameter, then reset the
accumulated gradient to zero.
"""
function update!(p::Param, η)
  p.x .-= p.Δx .* η
  p.Δx .= 0
  return p
end
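
# A worked example: with `p = param([1.0, 2.0])` and an accumulated gradient
# `p.Δx .= [0.5, 0.5]`, calling `update!(p, 0.1)` steps `p.x` to
# `[0.95, 1.95]` (that is, `x - η*Δx`) and resets `p.Δx` to zeros.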
state(x) = x

Base.size(p::Param) = size(p.x)
Base.size(p::Param, n) = size(p.x, n)

function Base.show(io::IO, p::Param)
  print(io, "Param", size(p.x))
end

Base.copy!(xs, p::Param) = copy!(xs, p.x)
Base.copy!(p::Param, xs) = copy!(p.x, xs)
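
# These make it easy to snapshot and restore parameter values, e.g. for
# checkpointing (a sketch; `W` is the hypothetical `Param` from above):
#
#     snapshot = zeros(5, 10)
#     copy!(snapshot, W)   # read the current value out of the Param
#     copy!(W, snapshot)   # write a saved value back in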

# Anonymous models
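
# A `Capacitor` wraps a raw DataFlow graph (`IVertex` comes from the
# DataFlow.jl package); calling one evaluates the graph via `interpmodel`,
# Flux's graph interpreter, defined elsewhere in the package.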
struct Capacitor
  graph::IVertex{Any}
end

(m::Capacitor)(xs...) = interpmodel(m, xs...)

graph(cap::Capacitor) = cap.graph