new params approach
This commit is contained in:
parent
4ad4e20c3e
commit
8e92403436
|
@ -16,6 +16,7 @@ include("compiler/code.jl")
|
|||
|
||||
include("cost.jl")
|
||||
include("activation.jl")
|
||||
include("layers/params.jl")
|
||||
include("layers/input.jl")
|
||||
include("layers/dense.jl")
|
||||
include("layers/sequence.jl")
|
||||
|
|
|
@ -24,15 +24,19 @@ function build_type(T, params)
|
|||
quote
|
||||
type $T <: Model
|
||||
$(params...)
|
||||
$([Symbol("Δ", s) for s in params]...)
|
||||
end
|
||||
$T($(params...)) = $T($(params...),
|
||||
$((:(zeros($p)) for p in params)...))
|
||||
$T($(map(x->:($x::AArray), params)...)) = $T($(map(x->:(param($x)), params)...))
|
||||
end
|
||||
end
|
||||
|
||||
"""
    deref_params(v)

Rewrite `v` with `mapconst`, replacing every item that matches `self.<name>`
with the call expression `state(self.<name>)`. Items that do not match the
pattern are passed through unchanged.
"""
function deref_params(v)
    function wrap(node)
        if @capture(node, self.name_)
            return :(state(self.$name))
        else
            return node
        end
    end
    return mapconst(wrap, v)
end
|
||||
|
||||
"""
    build_forward(body, args)

Build the forward-pass graph from `body`: parameter references are first
wrapped by `deref_params`, and the result is then passed through `cse`.

`args` is accepted for signature compatibility but is not used here.
"""
function build_forward(body, args)
    # Only the dereferenced graph is returned; running `cse` on the raw body
    # as well would just produce a value that is thrown away.
    return cse(deref_params(body))
end
|
||||
|
||||
function build_backward(body, x, params)
|
||||
|
@ -40,29 +44,20 @@ function build_backward(body, x, params)
|
|||
back = IVertex{Any}(Flow.Do())
|
||||
for param in params
|
||||
haskey(Δs, :(self.$param)) || continue
|
||||
k = Symbol("Δ", param)
|
||||
ksym = Expr(:quote, k)
|
||||
ex = Δs[:(self.$param)]
|
||||
thread!(back, @dvertex(setfield!(:self, ksym, :(self.$k) + ex)))
|
||||
ex = deref_params(ex)
|
||||
thread!(back, @dvertex(accumulate!(:(self.$param), ex)))
|
||||
end
|
||||
ex = Δs[x]
|
||||
ex = deref_params(ex)
|
||||
thread!(back, @flow(tuple($ex)))
|
||||
cse(back)
|
||||
end
|
||||
|
||||
"""
    build_update(T, params)

Generate the expression for an `update!(self::T)` method. For each name `p` in
`params` the generated body adds the field `Δp` to the field `p` and then
fills `Δp` with zeros.
"""
function build_update(T, params)
    steps = map(params) do name
        delta = Symbol("Δ", name)
        :(self.$name += self.$delta; fill!(self.$delta, 0))
    end
    return :(update!(self::$T) = $(steps...))
end
|
||||
|
||||
function process_type(ex)
|
||||
@capture(ex, type T_ fs__ end)
|
||||
@destruct [params = true || [],
|
||||
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
||||
@destruct [params = false || [],
|
||||
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
|
||||
@assert length(funcs) == 1
|
||||
args, body = process_func(funcs[1], params)
|
||||
@assert length(args) == 1
|
||||
|
@ -70,7 +65,7 @@ function process_type(ex)
|
|||
$(build_type(T, params))
|
||||
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
|
||||
$(build_update(T, params))
|
||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...)
|
||||
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||
nothing
|
||||
end |> esc
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
"""
    Param{T}

Mutable container holding a value `x` together with a second field `Δx` of the
same type `T`. `accumulate!` adds into `Δx`, and `update!` folds `Δx` back into
`x`.
"""
type Param{T}
    x::T    # current value
    Δx::T   # pending delta for `x`
end
|
||||
|
||||
"""
    param(x)

Wrap `x` in a `Param` whose `Δx` field starts out as `zero(x)`.
"""
function param(x)
    delta = zero(x)
    return Param(x, delta)
end
|
||||
|
||||
"""
    state(p::Param)
    state(x)

Return the plain value to compute with. A `Param` is unwrapped to its `x`
field; any other value is returned as-is.
"""
state(p::Param) = p.x  # unwrap: returning `p` itself would duplicate the fallback
state(x) = x
|
||||
|
||||
"""
    accumulate!(p::Param, Δ)

Add `Δ` to the delta stored in `p.Δx` and return `p`.
"""
function accumulate!(p::Param, Δ)
    p.Δx = p.Δx + Δ
    return p
end
|
||||
|
||||
"""
    update!(p::Param, η)

Apply the stored delta to the value, `p.x += p.Δx * η`, then reset `p.Δx` to
zero so that later `accumulate!` calls start from a clean slate. Returns `p`.
"""
function update!(p::Param, η)
    p.x += p.Δx * η
    # Clear the delta once it has been applied; otherwise the next round of
    # accumulation would stack on top of an already-applied delta.
    p.Δx = zero(p.Δx)
    return p
end
|
||||
|
||||
"""
    accumulate!(x, Δ)

Fallback method: ignore `Δ` and return `x` unchanged.
"""
function accumulate!(x, Δ)
    return x
end
|
Loading…
Reference in New Issue