new params approach
This commit is contained in:
parent
4ad4e20c3e
commit
8e92403436
@ -16,6 +16,7 @@ include("compiler/code.jl")
|
|||||||
|
|
||||||
include("cost.jl")
|
include("cost.jl")
|
||||||
include("activation.jl")
|
include("activation.jl")
|
||||||
|
include("layers/params.jl")
|
||||||
include("layers/input.jl")
|
include("layers/input.jl")
|
||||||
include("layers/dense.jl")
|
include("layers/dense.jl")
|
||||||
include("layers/sequence.jl")
|
include("layers/sequence.jl")
|
||||||
|
@ -24,15 +24,19 @@ function build_type(T, params)
|
|||||||
quote
|
quote
|
||||||
type $T <: Model
|
type $T <: Model
|
||||||
$(params...)
|
$(params...)
|
||||||
$([Symbol("Δ", s) for s in params]...)
|
|
||||||
end
|
end
|
||||||
$T($(params...)) = $T($(params...),
|
$T($(map(x->:($x::AArray), params)...)) = $T($(map(x->:(param($x)), params)...))
|
||||||
$((:(zeros($p)) for p in params)...))
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function deref_params(v)
|
||||||
|
mapconst(v) do x
|
||||||
|
@capture(x, self.p_) ? :(state(self.$p)) : x
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_forward(body, args)
|
function build_forward(body, args)
|
||||||
cse(body)
|
cse(deref_params(body))
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_backward(body, x, params)
|
function build_backward(body, x, params)
|
||||||
@ -40,29 +44,20 @@ function build_backward(body, x, params)
|
|||||||
back = IVertex{Any}(Flow.Do())
|
back = IVertex{Any}(Flow.Do())
|
||||||
for param in params
|
for param in params
|
||||||
haskey(Δs, :(self.$param)) || continue
|
haskey(Δs, :(self.$param)) || continue
|
||||||
k = Symbol("Δ", param)
|
|
||||||
ksym = Expr(:quote, k)
|
|
||||||
ex = Δs[:(self.$param)]
|
ex = Δs[:(self.$param)]
|
||||||
thread!(back, @dvertex(setfield!(:self, ksym, :(self.$k) + ex)))
|
ex = deref_params(ex)
|
||||||
|
thread!(back, @dvertex(accumulate!(:(self.$param), ex)))
|
||||||
end
|
end
|
||||||
ex = Δs[x]
|
ex = Δs[x]
|
||||||
|
ex = deref_params(ex)
|
||||||
thread!(back, @flow(tuple($ex)))
|
thread!(back, @flow(tuple($ex)))
|
||||||
cse(back)
|
cse(back)
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_update(T, params)
|
|
||||||
updates = []
|
|
||||||
for p in params
|
|
||||||
Δp = Symbol("Δ", p)
|
|
||||||
push!(updates, :(self.$p += self.$Δp; fill!(self.$Δp, 0)))
|
|
||||||
end
|
|
||||||
:(update!(self::$T) = $(updates...))
|
|
||||||
end
|
|
||||||
|
|
||||||
function process_type(ex)
|
function process_type(ex)
|
||||||
@capture(ex, type T_ fs__ end)
|
@capture(ex, type T_ fs__ end)
|
||||||
@destruct [params = true || [],
|
@destruct [params = false || [],
|
||||||
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
|
||||||
@assert length(funcs) == 1
|
@assert length(funcs) == 1
|
||||||
args, body = process_func(funcs[1], params)
|
args, body = process_func(funcs[1], params)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
@ -70,7 +65,7 @@ function process_type(ex)
|
|||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
||||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
|
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
|
||||||
$(build_update(T, params))
|
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...)
|
||||||
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||||
nothing
|
nothing
|
||||||
end |> esc
|
end |> esc
|
||||||
|
21
src/layers/params.jl
Normal file
21
src/layers/params.jl
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
type Param{T}
|
||||||
|
x::T
|
||||||
|
Δx::T
|
||||||
|
end
|
||||||
|
|
||||||
|
param(x) = Param(x, zero(x))
|
||||||
|
|
||||||
|
state(p::Param) = p
|
||||||
|
state(x) = x
|
||||||
|
|
||||||
|
function accumulate!(p::Param, Δ)
|
||||||
|
p.Δx += Δ
|
||||||
|
return p
|
||||||
|
end
|
||||||
|
|
||||||
|
function update!(p::Param, η)
|
||||||
|
p.x += p.Δx * η
|
||||||
|
return p
|
||||||
|
end
|
||||||
|
|
||||||
|
accumulate!(x, Δ) = x
|
Loading…
Reference in New Issue
Block a user