2017-05-22 15:18:41 +00:00
|
|
|
import DataFlow: cse
|
2017-03-09 01:10:44 +00:00
|
|
|
using MacroTools: @q
|
2016-08-18 21:06:12 +00:00
|
|
|
|
2017-04-18 18:02:55 +00:00
|
|
|
export @net
|
2016-08-23 15:22:11 +00:00
|
|
|
|
2017-05-03 17:33:14 +00:00
|
|
|
# Parse a model's forward function `ex` into its argument list and a DataFlow
# graph of its body.
#
# `ex` is normalised with `shortdef` and must then have the form
# `(args...,) -> body`; anything else raises an error at macro-expansion time
# (previously this surfaced as an opaque UndefVarError on `args`/`body`).
# The body is threaded through `MacroTools.flatten`, `liftloops`, `graphm` and
# `DataFlow.il`. Afterwards every vertex value found in `params` (bare field
# names) is rewritten to `self.<name>` so it refers to the model's field.
#
# Returns `(args, body)`: the captured argument expressions and the graph.
function graphdef(ex, params = [])
  @capture(shortdef(ex), (args__,) -> body_) ||
    error("@net: expected a forward function of the form `(args...) -> body`, got $ex")
  body = @> body MacroTools.flatten liftloops graphm DataFlow.il
  body = map(x -> x in params ? :(self.$x) : x, body)
  return args, body
end
|
|
|
|
|
2017-05-30 15:37:44 +00:00
|
|
|
# Turn the forward-pass graph `graph` into the graph exposed by `Flux.graph`.
#
# - Constant vertices whose value equals one of `args` are replaced by
#   `inputnode(i)`, where `i` is that argument's position in `args`.
# - `Offset` values become expressions constructing `Flux.Offset(name, n, default)`.
#   `default` is `self.<name>` when `name` is one of the model's fields, and
#   the bare `name` otherwise.
# - The result is wrapped in a `Flux.Frame(self)` vertex.
#
# `params` may contain raw field declarations (e.g. `W` or `y::Matrix`). They
# are reduced to bare names with `namify` before the membership test, so typed
# fields are recognised as well. Bare symbols are unaffected.
function makegraph(graph, args, params = [])
  pnames = map(namify, params)
  graph = prewalk(graph) do v
    isconstant(v) || return v
    # findfirst(A, x) returns 0 when x is not in A (Julia 0.6 API).
    i = findfirst(args, value(v[1]))
    i ≠ 0 ? inputnode(i) : v
  end
  graph = map(graph) do x
    x isa Offset || return x
    default = x.name in pnames ? :(self.$(x.name)) : x.name
    :(Flux.Offset($(Expr(:quote, x.name)), $(x.n), $default))
  end
  vertex(:(Flux.Frame(self)), graph)
end
|
|
|
|
|
2016-04-01 21:11:42 +00:00
|
|
|
# Emit the definition of `type T <: Model` whose fields are `params`.
#
# If any field is a bare Symbol (no type annotation), an outer constructor is
# also emitted. It accepts those fields as `AArray`s and wraps each one with
# `param` before delegating to the default constructor. Annotated fields are
# passed through by name (via `namify`).
function build_type(T, params)
  Tname = esc(T)
  typedef = quote
    type $Tname <: Model
      $(params...)
    end
  end
  untyped(field) = isexpr(field, Symbol)
  any(untyped, params) || return typedef
  signature = [untyped(field) ? :($field::AArray) : field for field in params]
  callargs  = [untyped(field) ? :(param($field)) : namify(field) for field in params]
  push!(typedef.args, :($Tname($(signature...)) = $Tname($(callargs...))))
  return typedef
end
|
|
|
|
|
2017-02-24 14:38:17 +00:00
|
|
|
# Rewrite each vertex value of the form `self.p` in the graph `v` into
# `Flux.state(self.p)`. Every other value is left untouched.
function deref_params(v)
  map(v) do node
    if @capture(node, self.p_)
      :(Flux.state(self.$p))
    else
      node
    end
  end
end
|
|
|
|
|
|
|
|
# Build the expression evaluated by a model's forward pass from its DataFlow
# graph `body`.
#
# Field references are wrapped with `Flux.state` (via `deref_params`), then
# common subexpressions are merged with `cse`. The result is converted to
# Julia syntax with `syntax` and `applylines`. `args` is currently unused.
# For a cyclic graph the returned expression raises an error when executed,
# instead of failing during macro expansion.
function build_forward(body, args)
  if iscyclic(body)
    return :(error("Can't run forward pass on a cyclic graph"))
  end
  lowered = syntax(cse(deref_params(body)))
  return applylines(lowered)
end
|
|
|
|
|
2016-10-29 22:36:39 +00:00
|
|
|
import Lazy: groupby
|
|
|
|
|
2017-05-22 15:18:41 +00:00
|
|
|
# Replace every `Param` value in the graph `v` with its `x` field; all other
# vertex values are kept as they are.
function reifyparams(v::IVertex)
  unwrap(val) = isa(val, Param) ? val.x : val
  return map(unwrap, v)
end
|
2016-11-13 20:46:35 +00:00
|
|
|
|
2017-03-09 00:12:49 +00:00
|
|
|
# TODO: type hints for parameters
|
|
|
|
|
2016-04-01 21:11:42 +00:00
|
|
|
# Expand `@net type T ... end`.
#
# The type body must contain exactly one function (`->` or `function` form),
# which becomes the forward pass. Every other entry is treated as a field.
# The generated code defines:
#   * the type itself (via `build_type`);
#   * a call method `(self::T)(args...)` that runs the forward pass;
#   * `Flux.update!(self::T, η)`, which calls `update!(self.<field>, η)` for
#     every field;
#   * `Flux.graph(self::T)`, which rebuilds the model graph produced by
#     `makegraph`.
# The generated block evaluates to `nothing`.
function process_type(ex)
  @capture(ex, type T_ fs__ end)
  # Split the body entries into fields (key `false`) and functions (key `true`),
  # each defaulting to an empty list when absent.
  @destruct [params = false || [],
             funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
  @assert length(funcs) == 1
  # Bare field names (via `namify`), used for `self.<name>` references.
  pnames = namify.(params)
  args, body = graphdef(funcs[1], pnames)
  self = esc(:self)
  quote
    $(build_type(T, params))
    $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
    # NOTE(review): the signature (including its `η`) is escaped, but the
    # body's `update!` and `η` are not; confirm macro hygiene binds the body's
    # `η` to the escaped argument.
    $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
    # NOTE(review): `params` (raw, possibly typed field declarations) is passed
    # here, whereas `graphdef` received the bare `pnames`; confirm `makegraph`
    # handles typed fields.
    $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
    nothing
  end
end
|
|
|
|
|
2016-08-23 15:22:11 +00:00
|
|
|
# Expand `@net (args...,) -> body` into an expression constructing a
# `Capacitor` from the function's DataFlow graph.
function process_anon(ex)
  args, body = graphdef(ex)
  # `makegraph` wraps its result in a `Flux.Frame(self)` vertex; an anonymous
  # net has no `self`, so `[1]` (presumably that vertex's first input) selects
  # the wrapped graph.
  inner = makegraph(body, args)[1]
  ctor = DataFlow.constructor(map(esc, inner))
  return :(Capacitor($ctor))
end
|
|
|
|
|
2017-05-03 17:33:23 +00:00
|
|
|
# Expand `@net f(xs...) = body` into `f = @net (xs...,) -> body; nothing`,
# i.e. bind `f` to the anonymous net built from the same arguments and body.
function process_def(ex)
  # TODO: make a singleton net type
  @capture(ex, f_(xs__) = body_)
  name = esc(f)
  lambda = esc(:(($(xs...),) -> $body))
  :($name = @net $lambda; nothing)
end
|
|
|
|
|
2017-03-20 19:57:00 +00:00
|
|
|
# `@net` entry point. After normalising `ex` with `shortdef`, it dispatches on
# the form of the expression:
#   * a `type` definition               -> `process_type`
#   * an anonymous `(args...,) -> body` -> `process_anon`
#   * a short definition `f(args) = b`  -> `process_def`
# Any other form raises an error during macro expansion.
macro net(ex)
  ex = shortdef(ex)
  if isexpr(ex, :type)
    process_type(ex)
  elseif @capture(ex, (__,) -> _)
    process_anon(ex)
  elseif @capture(ex, _(__) = _)
    process_def(ex)
  else
    error("Unsupported model expression $ex")
  end
end
|