Flux.jl/src/model.jl

57 lines
883 B
Julia
Raw Normal View History

2016-08-23 15:32:19 +00:00
export Model, back!, update!, param
# Basic model API
"""
    Model

Abstract supertype for models. Subtypes (for example `Capacitor`) are callable
and provide their own methods for `back!`, `update!` and `graph`.
"""
abstract Model
"""
    back!(m::Model, args...)

Backpropagate through `m`. This generic fallback always throws an error naming
the concrete type of `m`. Subtypes that support backprop override it, as
`Capacitor` does.

Any trailing arguments are accepted, so a call such as `back!(m, Δ, x)` on a
model without an implementation produces this message instead of a MethodError.
"""
back!(m::Model, args...) = error("Backprop not implemented for $(typeof(m))")
2016-08-24 14:41:30 +00:00
"""
    update!(m, η)

Generic fallback for objects without a specialised `update!`.
The step size `η` is ignored and `m` is handed back untouched.
"""
function update!(m, η)
  return m
end
2016-08-23 15:32:19 +00:00
2016-08-31 01:37:53 +00:00
"""
    graph(m)

Return the graph stored for `m`. This fallback reports that no graph is
available by returning `nothing`.
"""
function graph(m)
  return nothing
end
2016-08-23 15:32:19 +00:00
# Model parameters
"""
    Param{T}

A trainable parameter. It pairs the current value `x` with `Δx`, the gradient
accumulated for it; both fields have type `T`. Construct one with `param`, add
gradients with `accumulate!` and apply them with `update!`.
"""
# Declared mutable (`type`): `accumulate!`/`update!` assign to its fields.
type Param{T}
x::T
Δx::T
end
"""
    param(x)

Wrap `x` as a `Param` whose accumulated gradient starts at `zero(x)`.
"""
function param(x)
  initial_grad = zero(x)
  return Param(x, initial_grad)
end
"""
    state(p::Param)

Return the current value held by the parameter `p`.
"""
function state(p::Param)
  value = p.x
  return value
end
"""
    accumulate!(p::Param, Δ)

Add `Δ` to the gradient stored in `p` and return `p`.
"""
accumulate!(p::Param, Δ) = (p.Δx .+= Δ; p)
"""
    update!(p::Param, η)

Apply one step `x ← x - Δx .* η` using the gradient accumulated in `p`. Then
reset that gradient to zero so later `accumulate!` calls start from a clean
slate. Returns `p`.

Works for both array-valued and scalar parameters; `param(1.0)` produces the
scalar kind.
"""
function update!(p::Param, η)
  p.x .-= p.Δx .* η
  # Arrays are cleared in place, so the gradient buffer keeps its identity.
  # Scalars cannot be indexed into, so they are replaced by their zero.
  p.Δx = _zerograd!(p.Δx)
  return p
end

# Reset a gradient: zero an array in place (returning it), or return `zero(Δ)`.
_zerograd!(Δ::AbstractArray) = fill!(Δ, 0)
_zerograd!(Δ) = zero(Δ)
# Fallbacks for plain values that are not wrapped in a `Param`.

"""
    state(x)

The state of a plain value is the value itself.
"""
function state(x)
  return x
end

"""
    accumulate!(x, Δ)

Plain values carry no gradient, so accumulating into them does nothing;
`x` is returned as-is.
"""
function accumulate!(x, Δ)
  return x
end
# Anonymous models
export Capacitor
"""
    Capacitor(forward, backward, update, graph)

A model assembled from plain functions:
- calling it runs `forward`;
- `back!` runs `backward`;
- `update!` runs `update`;
- `graph` returns the stored `graph`.
"""
type Capacitor <: Model
forward::Function
backward::Function
update::Function
# NOTE(review): `IVertex` is not defined in this file; presumably it comes from
# the graph library imported by the enclosing module (DataFlow?), so confirm.
graph::IVertex{Any}
end
# Each `Capacitor` operation delegates to the function stored in the matching field.

# Calling the model runs its `forward` function on the given arguments.
function (cap::Capacitor)(args...)
  return cap.forward(args...)
end

# Backprop delegates to `backward` with the same arguments.
function back!(cap::Capacitor, args...)
  return cap.backward(args...)
end

# Parameter updates delegate to `update`, passing only the step size.
function update!(cap::Capacitor, η)
  return cap.update(η)
end

# The graph is whatever was stored at construction time.
function graph(cap::Capacitor)
  return cap.graph
end

# Compact display that does not print the wrapped functions.
function Base.show(io::IO, ::Capacitor)
  print(io, "Flux.Capacitor(...)")
end