handle recurrence
This commit is contained in:
parent
c55f955f1e
commit
edf69ac968
@ -6,7 +6,7 @@ export @model
|
|||||||
|
|
||||||
function process_func(ex, params = [])
|
function process_func(ex, params = [])
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
body = Flow.il(graphm(body))
|
body = Flow.il(graphm(unblock(body)))
|
||||||
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
||||||
return args, body
|
return args, body
|
||||||
end
|
end
|
||||||
@ -45,10 +45,12 @@ function deref_params(v)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function build_forward(body, args)
|
function build_forward(body, args)
|
||||||
cse(deref_params(body))
|
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
||||||
|
syntax(cse(deref_params(body)))
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_backward(body, x, params = [])
|
function build_backward(body, x, params = [])
|
||||||
|
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
|
||||||
Δs = invert(body)
|
Δs = invert(body)
|
||||||
back = IVertex{Any}(Flow.Do())
|
back = IVertex{Any}(Flow.Do())
|
||||||
for param in params
|
for param in params
|
||||||
@ -60,7 +62,7 @@ function build_backward(body, x, params = [])
|
|||||||
ex = Δs[x]
|
ex = Δs[x]
|
||||||
ex = deref_params(ex)
|
ex = deref_params(ex)
|
||||||
thread!(back, @flow(tuple($ex)))
|
thread!(back, @flow(tuple($ex)))
|
||||||
cse(back)
|
syntax(cse(back))
|
||||||
end
|
end
|
||||||
|
|
||||||
function process_type(ex)
|
function process_type(ex)
|
||||||
@ -73,8 +75,8 @@ function process_type(ex)
|
|||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
quote
|
quote
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
(self::$T)($(args...),) = $(build_forward(body, args))
|
||||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], pnames)))
|
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
||||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||||
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||||
nothing
|
nothing
|
||||||
@ -89,8 +91,8 @@ function process_anon(ex)
|
|||||||
end
|
end
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
:(Flux.Capacitor(
|
:(Flux.Capacitor(
|
||||||
($(args...)) -> $(syntax(build_forward(body, args))),
|
($(args...)) -> $(build_forward(body, args)),
|
||||||
(Δ, $(args...)) -> $(syntax(build_backward(body, args[1]))),
|
(Δ, $(args...)) -> $(build_backward(body, args[1])),
|
||||||
η -> $(map(p -> :(update!($p, η)), layers)...),
|
η -> $(map(p -> :(update!($p, η)), layers)...),
|
||||||
$(Flow.constructor(makegraph(body, args))))) |> esc
|
$(Flow.constructor(makegraph(body, args))))) |> esc
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user