support for inner layers
This commit is contained in:
parent
8e92403436
commit
cf2b168a55
|
@ -1,3 +1,5 @@
|
|||
# TODO: proper escaping
|
||||
|
||||
import Flow: mapconst, cse
|
||||
|
||||
function process_func(ex, params)
|
||||
|
@ -21,17 +23,22 @@ function makegraph(graph, args)
|
|||
end
|
||||
|
||||
function build_type(T, params)
|
||||
quote
|
||||
ex = quote
|
||||
type $T <: Model
|
||||
$(params...)
|
||||
end
|
||||
$T($(map(x->:($x::AArray), params)...)) = $T($(map(x->:(param($x)), params)...))
|
||||
end
|
||||
if any(x->isexpr(x, Symbol), params)
|
||||
push!(ex.args,
|
||||
:($T($(map(x->isexpr(x, Symbol) ? :($x::AArray) : x, params)...)) =
|
||||
$T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...))))
|
||||
end
|
||||
ex
|
||||
end
|
||||
|
||||
function deref_params(v)
|
||||
mapconst(v) do x
|
||||
@capture(x, self.p_) ? :(state(self.$p)) : x
|
||||
map(v) do x
|
||||
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(state(self.$p))) : x
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -59,13 +66,14 @@ function process_type(ex)
|
|||
@destruct [params = false || [],
|
||||
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
|
||||
@assert length(funcs) == 1
|
||||
args, body = process_func(funcs[1], params)
|
||||
pnames = namify.(params)
|
||||
args, body = process_func(funcs[1], pnames)
|
||||
@assert length(args) == 1
|
||||
quote
|
||||
$(build_type(T, params))
|
||||
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
|
||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...)
|
||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], pnames)))
|
||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||
nothing
|
||||
end |> esc
|
||||
|
|
|
@ -8,3 +8,11 @@ end
|
|||
|
||||
Dense(in::Integer, out::Integer; init = randn) =
|
||||
Dense(init(out, in), init(out))
|
||||
|
||||
@model type Sigmoid
|
||||
layer::Model
|
||||
x -> σ(layer(x))
|
||||
end
|
||||
|
||||
Sigmoid(in::Integer, out::Integer; init = randn) =
|
||||
Sigmoid(Dense(in, out, init = init))
|
||||
|
|
Loading…
Reference in New Issue