support for inner layers

This commit is contained in:
Mike J Innes 2016-08-23 14:14:20 +01:00
parent 8e92403436
commit cf2b168a55
2 changed files with 23 additions and 7 deletions

View File

@ -1,3 +1,5 @@
# TODO: proper escaping
import Flow: mapconst, cse
function process_func(ex, params)
@ -21,17 +23,22 @@ function makegraph(graph, args)
end
function build_type(T, params)
quote
ex = quote
type $T <: Model
$(params...)
end
$T($(map(x->:($x::AArray), params)...)) = $T($(map(x->:(param($x)), params)...))
end
if any(x->isexpr(x, Symbol), params)
push!(ex.args,
:($T($(map(x->isexpr(x, Symbol) ? :($x::AArray) : x, params)...)) =
$T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...))))
end
ex
end
function deref_params(v)
mapconst(v) do x
@capture(x, self.p_) ? :(state(self.$p)) : x
map(v) do x
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(state(self.$p))) : x
end
end
@ -59,13 +66,14 @@ function process_type(ex)
@destruct [params = false || [],
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
@assert length(funcs) == 1
args, body = process_func(funcs[1], params)
pnames = namify.(params)
args, body = process_func(funcs[1], pnames)
@assert length(args) == 1
quote
$(build_type(T, params))
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...)
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], pnames)))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
nothing
end |> esc

View File

@ -8,3 +8,11 @@ end
Dense(in::Integer, out::Integer; init = randn) =
Dense(init(out, in), init(out))
@model type Sigmoid
layer::Model
x -> σ(layer(x))
end
Sigmoid(in::Integer, out::Integer; init = randn) =
Sigmoid(Dense(in, out, init = init))