handle recurrence

This commit is contained in:
Mike J Innes 2016-08-26 17:27:46 +01:00
parent c55f955f1e
commit edf69ac968

View File

@ -6,7 +6,7 @@ export @model
function process_func(ex, params = []) function process_func(ex, params = [])
@capture(shortdef(ex), (args__,) -> body_) @capture(shortdef(ex), (args__,) -> body_)
body = Flow.il(graphm(body)) body = Flow.il(graphm(unblock(body)))
body = mapconst(x -> x in params ? :(self.$x) : x, body) body = mapconst(x -> x in params ? :(self.$x) : x, body)
return args, body return args, body
end end
@ -45,10 +45,12 @@ function deref_params(v)
end end
function build_forward(body, args) function build_forward(body, args)
cse(deref_params(body)) iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
syntax(cse(deref_params(body)))
end end
function build_backward(body, x, params = []) function build_backward(body, x, params = [])
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
Δs = invert(body) Δs = invert(body)
back = IVertex{Any}(Flow.Do()) back = IVertex{Any}(Flow.Do())
for param in params for param in params
@ -60,7 +62,7 @@ function build_backward(body, x, params = [])
ex = Δs[x] ex = Δs[x]
ex = deref_params(ex) ex = deref_params(ex)
thread!(back, @flow(tuple($ex))) thread!(back, @flow(tuple($ex)))
cse(back) syntax(cse(back))
end end
function process_type(ex) function process_type(ex)
@ -73,8 +75,8 @@ function process_type(ex)
@assert length(args) == 1 @assert length(args) == 1
quote quote
$(build_type(T, params)) $(build_type(T, params))
(self::$T)($(args...),) = $(syntax(build_forward(body, args))) (self::$T)($(args...),) = $(build_forward(body, args))
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], pnames))) back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...) update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
graph(::$T) = $(Flow.constructor(makegraph(body, args))) graph(::$T) = $(Flow.constructor(makegraph(body, args)))
nothing nothing
@ -89,8 +91,8 @@ function process_anon(ex)
end end
@assert length(args) == 1 @assert length(args) == 1
:(Flux.Capacitor( :(Flux.Capacitor(
($(args...)) -> $(syntax(build_forward(body, args))), ($(args...)) -> $(build_forward(body, args)),
(Δ, $(args...)) -> $(syntax(build_backward(body, args[1]))), (Δ, $(args...)) -> $(build_backward(body, args[1])),
η -> $(map(p -> :(update!($p, η)), layers)...), η -> $(map(p -> :(update!($p, η)), layers)...),
$(Flow.constructor(makegraph(body, args))))) |> esc $(Flow.constructor(makegraph(body, args))))) |> esc
end end