support for inner layers
This commit is contained in:
parent
8e92403436
commit
cf2b168a55
@ -1,3 +1,5 @@
|
|||||||
|
# TODO: proper escaping
|
||||||
|
|
||||||
import Flow: mapconst, cse
|
import Flow: mapconst, cse
|
||||||
|
|
||||||
function process_func(ex, params)
|
function process_func(ex, params)
|
||||||
@ -21,17 +23,22 @@ function makegraph(graph, args)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function build_type(T, params)
|
function build_type(T, params)
|
||||||
quote
|
ex = quote
|
||||||
type $T <: Model
|
type $T <: Model
|
||||||
$(params...)
|
$(params...)
|
||||||
end
|
end
|
||||||
$T($(map(x->:($x::AArray), params)...)) = $T($(map(x->:(param($x)), params)...))
|
|
||||||
end
|
end
|
||||||
|
if any(x->isexpr(x, Symbol), params)
|
||||||
|
push!(ex.args,
|
||||||
|
:($T($(map(x->isexpr(x, Symbol) ? :($x::AArray) : x, params)...)) =
|
||||||
|
$T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...))))
|
||||||
|
end
|
||||||
|
ex
|
||||||
end
|
end
|
||||||
|
|
||||||
function deref_params(v)
|
function deref_params(v)
|
||||||
mapconst(v) do x
|
map(v) do x
|
||||||
@capture(x, self.p_) ? :(state(self.$p)) : x
|
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(state(self.$p))) : x
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -59,13 +66,14 @@ function process_type(ex)
|
|||||||
@destruct [params = false || [],
|
@destruct [params = false || [],
|
||||||
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
|
funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
|
||||||
@assert length(funcs) == 1
|
@assert length(funcs) == 1
|
||||||
args, body = process_func(funcs[1], params)
|
pnames = namify.(params)
|
||||||
|
args, body = process_func(funcs[1], pnames)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
quote
|
quote
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
||||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
|
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], pnames)))
|
||||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), params)...)
|
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||||
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||||
nothing
|
nothing
|
||||||
end |> esc
|
end |> esc
|
||||||
|
@ -8,3 +8,11 @@ end
|
|||||||
|
|
||||||
Dense(in::Integer, out::Integer; init = randn) =
|
Dense(in::Integer, out::Integer; init = randn) =
|
||||||
Dense(init(out, in), init(out))
|
Dense(init(out, in), init(out))
|
||||||
|
|
||||||
|
@model type Sigmoid
|
||||||
|
layer::Model
|
||||||
|
x -> σ(layer(x))
|
||||||
|
end
|
||||||
|
|
||||||
|
Sigmoid(in::Integer, out::Integer; init = randn) =
|
||||||
|
Sigmoid(Dense(in, out, init = init))
|
||||||
|
Loading…
Reference in New Issue
Block a user