2016-08-23 15:32:19 +00:00
|
|
|
|
export Model, back!, update!, param
|
|
|
|
|
|
|
|
|
|
# Basic model API
|
|
|
|
|
|
2016-12-15 22:31:27 +00:00
|
|
|
|
"""
|
|
|
|
|
(m::Model)(X...) => Y
|
|
|
|
|
|
|
|
|
|
A "model" is a function with state. For example, a logistic regression is the
|
|
|
|
|
function
|
|
|
|
|
|
|
|
|
|
x -> σ(x * W + b)
|
|
|
|
|
|
|
|
|
|
where `W` and `b` are a trainable matrix and vector of weights respectively. The
|
|
|
|
|
`Model` abstract type is used loosely; in general the concept of a model is
|
|
|
|
|
closer to a protocol, and models don't need to inherit from this type. Normal
|
|
|
|
|
Julia functions are models with 0 parameters, for example.
|
|
|
|
|
"""
|
2017-03-14 16:51:31 +00:00
|
|
|
|
abstract type Model end  # used loosely as a protocol marker; models need not subtype it
|
2016-08-23 15:32:19 +00:00
|
|
|
|
|
2016-12-15 22:31:27 +00:00
|
|
|
|
"""
|
|
|
|
|
back!(m::Model, ΔY, X...) => ΔX
|
|
|
|
|
|
|
|
|
|
Backpropagate the gradient `ΔY` through the model `m`, accumulating the
|
|
|
|
|
gradients of any parameters. Returns the gradient of the input `X`. Gradients
|
|
|
|
|
may be arrays or tuples of arrays (for multiple inputs/outputs).
|
|
|
|
|
"""
|
|
|
|
|
function back!(m::Model, Δ, xs...)
    # Fallback for models without a backward pass: always raises, naming the model type.
    error("Backprop not implemented for $(typeof(m))")
end
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
update!(m::Model, η) => m
|
|
|
|
|
|
|
|
|
|
Update the parameters of the model `m` using the accumulated gradients from
|
|
|
|
|
`back!`, using the learning rate `η`.
|
|
|
|
|
"""
|
2016-08-24 14:41:30 +00:00
|
|
|
|
function update!(m, η)
    # A generic object has nothing to update: hand it back untouched.
    return m
end
|
2016-08-23 15:32:19 +00:00
|
|
|
|
|
2016-12-15 22:31:27 +00:00
|
|
|
|
"""
|
|
|
|
|
graph(m::Model) => ::IVertex{Any} | nothing
|
|
|
|
|
|
|
|
|
|
Returns the graph representation of the model, if any. Most models are built
|
|
|
|
|
from lower-level components and can simply implement this method to get most of
|
|
|
|
|
Flux's functionality. If this method isn't available, functionality like
|
|
|
|
|
backpropagation or conversion for backend must be implemented on a case-by-case
|
|
|
|
|
basis. Alternatively, one can implement this method and override individual
|
|
|
|
|
methods as necessary.
|
|
|
|
|
"""
|
2016-08-31 01:37:53 +00:00
|
|
|
|
function graph(m)
    # No graph representation is known for an arbitrary object.
    return nothing
end
|
|
|
|
|
|
2017-03-21 01:32:12 +00:00
|
|
|
|
"""
|
|
|
|
|
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
|
|
|
|
|
unlike direct calling, it does not try to apply batching and simply uses the
|
|
|
|
|
inputs directly.
|
|
|
|
|
|
|
|
|
|
This function should be considered an implementation detail; it will
|
|
|
|
|
eventually be replaced by a non-hacky way of doing batching.
|
|
|
|
|
"""
|
|
|
|
|
function runmodel end  # declares the generic function only; no methods are defined here
|
|
|
|
|
|
2016-08-23 15:32:19 +00:00
|
|
|
|
# Model parameters
|
|
|
|
|
|
2017-03-08 15:36:25 +00:00
|
|
|
|
# TODO: should be AbstractArray?
|
2016-12-15 22:31:27 +00:00
|
|
|
|
"""
|
|
|
|
|
A `Param` object stores a parameter array along with an accumulated delta to
|
|
|
|
|
that array. When converting to backends like TensorFlow, identical `Param`s will
|
|
|
|
|
result in identical variable objects, making model reuse trivial.
|
|
|
|
|
"""
|
2017-03-14 17:56:03 +00:00
|
|
|
|
struct Param{T}
    # Current value of the parameter.
    x::T
    # Accumulated delta for `x`; has the same type `T` as the value.
    # NOTE(review): the struct is immutable, so neither field can be rebound;
    # changes have to be made in place on the stored objects.
    Δx::T
end
|
|
|
|
|
|
2016-12-15 22:31:27 +00:00
|
|
|
|
"""
|
|
|
|
|
param(x::T) => ::Param{T}
|
|
|
|
|
|
|
|
|
|
Convenience method for creating a `Param` object for a given array.
|
|
|
|
|
"""
|
2016-08-23 15:32:19 +00:00
|
|
|
|
function param(x)
    # Pair the value with a zero accumulator built from it.
    Δ = zero(x)
    return Param(x, Δ)
end
|
|
|
|
|
|
|
|
|
|
function state(p::Param)
    # The state of a parameter is its current value.
    return p.x
end
|
|
|
|
|
|
2016-12-15 22:31:27 +00:00
|
|
|
|
"""
|
|
|
|
|
accumulate!(p::Param, Δ) => p
|
|
|
|
|
|
|
|
|
|
Accumulates the update `Δ` on `p`. The value of `p` won't change until
|
|
|
|
|
`update!`.
|
|
|
|
|
"""
|
2016-08-23 15:32:19 +00:00
|
|
|
|
function accumulate!(p::Param, Δ)
    # `Param` is an immutable struct, so `p.Δx += Δ` (which rebinds the field)
    # cannot work. Add into the existing accumulator in place instead; this also
    # matches `update!`, which already mutates `p.x` and `p.Δx` in place.
    p.Δx .+= Δ
    return p
end
|
|
|
|
|
|
2016-12-15 22:31:27 +00:00
|
|
|
|
"""
|
|
|
|
|
update!(p::Param, η) => p
|
|
|
|
|
|
|
|
|
|
Apply the accumulated updates to the value of the parameter.
|
|
|
|
|
"""
|
2016-08-23 15:32:19 +00:00
|
|
|
|
function update!(p::Param, η)
    # Step the value against the accumulated delta, scaled by the learning rate `η`.
    p.x .-= p.Δx .* η
    # Reset the accumulator in place. `fill!` replaces the scalar-to-slice
    # assignment `p.Δx[:] = 0`, which newer Julia versions reject.
    fill!(p.Δx, 0)
    return p
end
|
|
|
|
|
|
|
|
|
|
# Fallbacks for values that are not `Param`s.

# A plain value is its own state.
function state(x)
    return x
end

# There is no accumulator to add to: the delta is ignored and `x` is returned.
function accumulate!(x, Δ)
    return x
end
|
|
|
|
|
|
2017-02-20 23:15:27 +00:00
|
|
|
|
# Size queries are forwarded to the wrapped value `x`.
function Base.size(p::Param)
    return size(p.x)
end

function Base.size(p::Param, n)
    return size(p.x, n)
end
|
2016-10-04 21:23:10 +00:00
|
|
|
|
|
2016-11-14 15:42:29 +00:00
|
|
|
|
# Compact display: the word "Param" followed by the size tuple of the wrapped value.
Base.show(io::IO, p::Param) = print(io, "Param", size(p.x))
|
|
|
|
|
|
2017-03-08 15:36:25 +00:00
|
|
|
|
# Copying to or from a `Param` acts on its wrapped value `x`.
function Base.copy!(xs, p::Param)
    return copy!(xs, p.x)
end

function Base.copy!(p::Param, xs)
    return copy!(p.x, xs)
end
|
|
|
|
|
|
2016-08-23 15:32:19 +00:00
|
|
|
|
# Anonymous models
|
|
|
|
|
|
|
|
|
|
export Capacitor
|
|
|
|
|
|
2017-03-14 17:56:03 +00:00
|
|
|
|
struct Capacitor <: Model
    # Graph that defines the model; returned as-is by `graph(::Capacitor)`.
    graph::IVertex{Any}
end
|
|
|
|
|
|
2017-03-20 19:57:00 +00:00
|
|
|
|
function (m::Capacitor)(xs...)
    # Calling a `Capacitor` delegates to `interpmodel` with the same inputs.
    return interpmodel(m, xs...)
end
|
2016-08-23 15:32:19 +00:00
|
|
|
|
|
|
|
|
|
function graph(cap::Capacitor)
    # A `Capacitor` stores its graph directly.
    return cap.graph
end
|
2017-03-21 01:32:12 +00:00
|
|
|
|
|
|
|
|
|
# Recurrent Models
|
|
|
|
|
|
2017-04-26 16:42:47 +00:00
|
|
|
|
mutable struct Stateful <: Model
    # Wrapped model; run through `runmodel` with the current state and the input.
    model
    # State passed into the wrapped model on the most recent call.
    istate::Vector{Any}
    # State returned by the wrapped model on the most recent call.
    ostate::Vector{Any}
end
|
|
|
|
|
|
2017-04-26 16:42:47 +00:00
|
|
|
|
function Stateful(model, state)
    # The same `state` object seeds both the input state and the output state.
    return Stateful(model, state, state)
end
|
|
|
|
|
|
2017-03-21 01:32:12 +00:00
|
|
|
|
function (m::Stateful)(x)
    # The previous call's output state becomes this call's input state.
    m.istate = m.ostate
    inputstate = (m.istate...,)
    newstate, output = runmodel(m.model, inputstate, x)
    # Keep the returned state as a vector for the next call.
    m.ostate = collect(newstate)
    return output
end
|
|
|
|
|
|
2017-04-26 16:42:47 +00:00
|
|
|
|
function back!(m::Stateful, Δ, x)
    # The wrapped model returns `(state, y)`, so its output gradient is a pair:
    # a zero gradient for every output-state entry, plus `Δ` for `y`.
    # `zero` is used rather than `zeros`: `zeros(value)` is not available on
    # newer Julia versions and interprets an integer argument as a length.
    Δstate = (zero.(m.ostate)...,)
    Δs = back!(m.model, (Δstate, Δ), (m.istate...,), x)
    # Drop the first entry (presumably the gradient of the input state — confirm).
    return Δs[2:end]
end
|
|
|
|
|
|
2017-04-27 16:27:08 +00:00
|
|
|
|
function update!(m::Stateful, η)
    # Updating the wrapper simply updates the wrapped model.
    return update!(m.model, η)
end
|
|
|
|
|
|
2017-03-29 17:30:28 +00:00
|
|
|
|
# Anything that is not a `Stateful` wrapper is returned unchanged.
function stateless(m)
    return m
end
|
|
|
|
|
# Unwrap a `Stateful` to the model it holds.
function stateless(m::Stateful)
    return m.model
end
|
|
|
|
|
|
2017-03-29 18:25:50 +00:00
|
|
|
|
struct SeqModel <: Model
    # Wrapped model; called on the sequence handed over by `runseq`.
    model
    # Required sequence length, asserted on every call.
    steps::Int
end
|
|
|
|
|
|
2017-04-19 16:33:55 +00:00
|
|
|
|
# Tuples are passed to `f` unchanged.
function runseq(f, xs::Tuple...)
    return f(xs...)
end
|
|
|
|
|
# Arrays are converted to tuples via `unstack(x, 2)`, `f` is applied, and the
# result is recombined with `stack(…, 2)`.
# NOTE(review): presumably this splits and rejoins along dimension 2 — confirm
# against the definitions of `unstack`/`stack`.
function runseq(f, xs::AbstractArray...)
    seqs = map(x -> (unstack(x, 2)...,), xs)
    return stack(f(seqs...), 2)
end
|
|
|
|
|
# `BatchSeq` inputs are reduced with `rawbatch`, run through `runseq` again, and
# the result is passed to `rebatchseq`.
function runseq(f, xs::BatchSeq...)
    raw = rawbatch.(xs)
    return rebatchseq(runseq(f, raw...))
end
|
|
|
|
|
|
|
|
|
|
function (m::SeqModel)(x)
    runseq(x) do seq
        # The message reports `length(seq)`, the quantity actually being checked.
        # It previously interpolated `size(seq, 2)`, which is a different value
        # and throws for tuples, hiding the intended assertion failure.
        @assert length(seq) == m.steps "Expected seq length $(m.steps), got $(length(seq))"
        m.model(seq)
    end
end
|
2017-03-28 18:54:32 +00:00
|
|
|
|
|
2017-04-26 16:42:47 +00:00
|
|
|
|
function back!(m::SeqModel, Δ, x)
    # Backpropagate through the wrapped model, keeping only the first returned gradient.
    step = (Δ, x) -> back!(m.model, Δ, x)[1]
    # The result is wrapped in a 1-tuple.
    return (runseq(step, Δ, x),)
end
|
2017-04-27 16:27:08 +00:00
|
|
|
|
|
|
|
|
|
function update!(m::SeqModel, η)
    # Updating the sequence wrapper simply updates the wrapped model.
    return update!(m.model, η)
end
|