# Flux.jl/src/model.jl
export Model, back!, update!, param
# Basic model API
"""
(m::Model)(X...) => Y
A "model" is a function with state. For example, a logistic regression is the
function
x -> σ(x * W + b)
where `W` and `b` are a trainable matrix and vector of weights respectively. The
`Model` abstract type is used loosely; in general the concept of a model is
closer to a protocol, and models don't need to inherit from this type. Normal
Julia functions are models with 0 parameters, for example.
"""
abstract type Model end
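# Illustrative sketch (not part of Flux's API): the logistic regression above,
# written as a callable struct using the `Param` type defined later in this
# file. `Logistic` is a hypothetical name and `σ` is assumed to be Flux's
# sigmoid.
#
#     struct Logistic <: Model
#       W::Param
#       b::Param
#     end
#
#     (m::Logistic)(x) = σ.(x * m.W.x .+ m.b.x)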
"""
back!(m::Model, ΔY, X...) => ΔX
Backpropagate the gradient `ΔY` through the model `m`, accumulating the
gradients of any parameters. Returns the gradient of the input `X`. Gradients
may be arrays or tuples of arrays (for multiple inputs/outputs).
"""
back!(m::Model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
"""
update!(m::Model, η) => m
Update the parameters of the model `m` with the gradients accumulated by
`back!`, scaled by the learning rate `η`.
"""
update!(m, η) = m
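# Sketch of the training-step protocol that `back!` and `update!` describe,
# with a hypothetical model `m`, input `x`, target `y` and learning rate 0.1:
#
#     ŷ = m(x)           # forward pass
#     Δ = ŷ .- y         # gradient of a squared-error-style loss w.r.t. ŷ
#     back!(m, Δ, x)     # accumulate parameter gradients; returns the input gradient
#     update!(m, 0.1)    # apply the accumulated gradients, then keep training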
"""
graph(m::Model) => ::IVertex{Any} | nothing
Returns the graph representation of the model, if any. Most models are built
from lower-level components and can simply implement this method to get most of
Flux's functionality. If this method isn't available, functionality like
backpropagation or conversion to a backend must be implemented on a case-by-case
basis. Alternatively, one can implement this method and override individual
methods as necessary.
"""
graph(m) = nothing
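# Sketch: a wrapper that delegates to a single child model can expose the
# child's graph, so backprop and backend conversion see through it. `Wrapper`
# is a hypothetical type.
#
#     struct Wrapper <: Model
#       inner
#     end
#
#     graph(w::Wrapper) = graph(w.inner)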
"""
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
unlike direct calling, it does not try to apply batching and simply uses the
inputs directly.
This function should be considered an implementation detail; it will
eventually be replaced by a non-hacky way of doing batching.
"""
function runmodel end
# Model parameters
# TODO: should be AbstractArray?
"""
A `Param` object stores a parameter array along with an accumulated delta to
that array. When converting to backends like TensorFlow, identical `Param`s will
result in identical variable objects, making model reuse trivial.
"""
struct Param{T}
  x::T
  Δx::T
end
"""
param(x::T) => ::Param{T}
Convenience method for creating a `Param` object for a given array.
"""
param(x) = Param(x, zero(x))
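# Sketch: wrapping an array yields a parameter whose delta starts at zero and
# has the same shape as the value.
#
#     W = param(randn(5, 10))
#     size(W)    # (5, 10)
#     W.Δx       # 5×10 array of zeros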
state(p::Param) = p.x
"""
accumulate!(p::Param, Δ) => p
Accumulates the update `Δ` on `p`. The value of `p` won't change until
`update!` is called.
"""
function accumulate!(p::Param, Δ)
  p.Δx .+= Δ
  return p
end
"""
update!(p::Param, η) => p
Apply the accumulated updates to the value of the parameter, scaled by the
learning rate `η`, and reset the accumulated delta.
"""
function update!(p::Param, η)
  p.x .-= p.Δx .* η
  p.Δx .= 0
  return p
end
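# Sketch of a parameter's lifecycle: deltas accumulate across calls until
# `update!` applies them and resets the accumulator.
#
#     p = param(ones(2, 2))
#     accumulate!(p, ones(2, 2))   # p.Δx == ones(2, 2); p.x unchanged
#     accumulate!(p, ones(2, 2))   # p.Δx == 2 .* ones(2, 2)
#     update!(p, 0.5)              # p.x .-= 0.5 .* p.Δx, then p.Δx is zeroed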
state(x) = x
accumulate!(x, Δ) = x
Base.size(p::Param) = size(p.x)
Base.size(p::Param, n) = size(p.x, n)
function Base.show(io::IO, p::Param)
  print(io, "Param", size(p.x))
end
Base.copy!(xs, p::Param) = copy!(xs, p.x)
Base.copy!(p::Param, xs) = copy!(p.x, xs)
# Anonymous models
export Capacitor
struct Capacitor <: Model
  graph::IVertex{Any}
end
(m::Capacitor)(xs...) = interpmodel(m, xs...)
graph(cap::Capacitor) = cap.graph
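# Sketch: any model that exposes `graph` can be re-wrapped as a `Capacitor`
# and called through the graph interpreter. `m` is a hypothetical graph-backed
# model.
#
#     cap = Capacitor(graph(m))
#     cap(x)    # runs the stored graph on `x` via `interpmodel`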
# Recurrent Models
mutable struct Stateful <: Model
  model
  istate::Vector{Any}
  ostate::Vector{Any}
end
Stateful(model, state) = Stateful(model, state, state)
function (m::Stateful)(x)
  m.istate = m.ostate
  state, y = runmodel(m.model, (m.istate...,), x)
  m.ostate = collect(state)
  return y
end
function back!(m::Stateful, Δ, x)
  back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
end
stateless(m) = m
stateless(m::Stateful) = m.model
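# Sketch: `Stateful` threads hidden state between calls for a recurrent cell
# whose `runmodel` returns `(state, y)`. `cell` and the initial hidden state
# `h0` are hypothetical.
#
#     r = Stateful(cell, Any[h0])
#     y1 = r(x1)      # runs the cell from h0 and stores the new state in `r.ostate`
#     y2 = r(x2)      # continues from the state produced by the previous call
#     stateless(r)    # recovers the underlying `cell`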
struct SeqModel <: Model
  model
  steps::Int
end
runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
function (m::SeqModel)(x)
  runseq(x) do x
    @assert length(x) == m.steps "Expected seq length $(m.steps), got $(length(x))"
    m.model(x)
  end
end
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
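# Sketch: `SeqModel` fixes the number of time steps and relies on `runseq` to
# unstack a batched sequence (time along dimension 2) into per-step inputs.
# The shapes and the inner model `m` are illustrative.
#
#     seq = SeqModel(m, 10)
#     xs = rand(1, 10, 50)   # batch × time × features
#     ys = seq(xs)           # asserts 10 steps, applies `m`, restacks the outputs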