small reorg

This commit is contained in:
Mike J Innes 2016-08-23 16:32:19 +01:00
parent b8565a4cc3
commit 2635283bf1
5 changed files with 53 additions and 47 deletions

View File

@ -4,13 +4,7 @@ using MacroTools, Lazy, Flow
# Zero Flux Given
export Model, back!, update!
abstract Model
back!(m::Model, ) = error("Backprop not implemented for $(typeof(m))")
update!(m::Model, η) = m
include("model.jl")
include("utils.jl")
include("compiler/diff.jl")
@ -18,8 +12,6 @@ include("compiler/code.jl")
include("cost.jl")
include("activation.jl")
include("layers/params.jl")
include("layers/anon.jl")
include("layers/input.jl")
include("layers/dense.jl")
include("layers/sequence.jl")

View File

@ -88,7 +88,7 @@ function process_anon(ex)
isexpr(value(v), Symbol) && push!(layers, value(v))
end
@assert length(args) == 1
:(Capacitor(
:(Flux.Capacitor(
($(args...)) -> $(syntax(build_forward(body, args))),
(Δ, $(args...)) -> $(syntax(build_backward(body, args[1]))),
η -> $(map(p -> :(update!($p, η)), layers)...),

View File

@ -1,16 +0,0 @@
export Capacitor
type Capacitor <: Model
forward::Function
backward::Function
update::Function
graph::IVertex{Any}
end
(cap::Capacitor)(args...) = cap.forward(args...)
back!(cap::Capacitor, args...) = cap.backward(args...)
update!(cap::Capacitor, η) = cap.update(η)
graph(cap::Capacitor) = cap.graph

View File

@ -1,21 +0,0 @@
type Param{T}
x::T
Δx::T
end
param(x) = Param(x, zero(x))
state(p::Param) = p.x
function accumulate!(p::Param, Δ)
p.Δx .+= Δ
return p
end
function update!(p::Param, η)
p.x .+= p.Δx .* η
return p
end
state(x) = x
accumulate!(x, Δ) = x

51
src/model.jl Normal file
View File

@ -0,0 +1,51 @@
export Model, back!, update!, param
# Basic model API
abstract Model
back!(m::Model, ) = error("Backprop not implemented for $(typeof(m))")
update!(m::Model, η) = m
# Model parameters
type Param{T}
x::T
Δx::T
end
param(x) = Param(x, zero(x))
state(p::Param) = p.x
function accumulate!(p::Param, Δ)
p.Δx .+= Δ
return p
end
function update!(p::Param, η)
p.x .+= p.Δx .* η
return p
end
state(x) = x
accumulate!(x, Δ) = x
# Anonymous models
export Capacitor
type Capacitor <: Model
forward::Function
backward::Function
update::Function
graph::IVertex{Any}
end
(cap::Capacitor)(args...) = cap.forward(args...)
back!(cap::Capacitor, args...) = cap.backward(args...)
update!(cap::Capacitor, η) = cap.update(η)
graph(cap::Capacitor) = cap.graph