support for inner layers

This commit is contained in:
Mike J Innes 2016-08-23 14:14:20 +01:00
parent 8e92403436
commit cf2b168a55
2 changed files with 23 additions and 7 deletions

View File

@ -1,3 +1,5 @@
# TODO: proper escaping
import Flow: mapconst, cse import Flow: mapconst, cse
function process_func(ex, params) function process_func(ex, params)
@ -21,17 +23,22 @@ function makegraph(graph, args)
end end
function build_type(T, params) function build_type(T, params)
quote ex = quote
type $T <: Model type $T <: Model
$(params...) $(params...)
end end
$T($(map(x->:($x::AArray), params)...)) = $T($(map(x->:(param($x)), params)...))
end end
if any(x->isexpr(x, Symbol), params)
push!(ex.args,
:($T($(map(x->isexpr(x, Symbol) ? :($x::AArray) : x, params)...)) =
$T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...))))
end
ex
end end
function deref_params(v) function deref_params(v)
mapconst(v) do x map(v) do x
@capture(x, self.p_) ? :(state(self.$p)) : x isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(state(self.$p))) : x
end end
end end
@ -59,13 +66,14 @@ function process_type(ex)
@destruct [params = false || [], @destruct [params = false || [],
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs) funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
@assert length(funcs) == 1 @assert length(funcs) == 1
args, body = process_func(funcs[1], params) pnames = namify.(params)
args, body = process_func(funcs[1], pnames)
@assert length(args) == 1 @assert length(args) == 1
quote quote
$(build_type(T, params)) $(build_type(T, params))
(self::$T)($(args...),) = $(syntax(build_forward(body, args))) (self::$T)($(args...),) = $(syntax(build_forward(body, args)))
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params))) back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], pnames)))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...) update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
graph(::$T) = $(Flow.constructor(makegraph(body, args))) graph(::$T) = $(Flow.constructor(makegraph(body, args)))
nothing nothing
end |> esc end |> esc

View File

@ -8,3 +8,11 @@ end
Dense(in::Integer, out::Integer; init = randn) = Dense(in::Integer, out::Integer; init = randn) =
Dense(init(out, in), init(out)) Dense(init(out, in), init(out))
@model type Sigmoid
layer::Model
x -> σ(layer(x))
end
Sigmoid(in::Integer, out::Integer; init = randn) =
Sigmoid(Dense(in, out, init = init))