From b145b46cbbe78bffd0a5048c7e8087b3711e5909 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 30 May 2017 16:37:44 +0100 Subject: [PATCH] fix for recurrent anon models --- src/compiler/code.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/compiler/code.jl b/src/compiler/code.jl index ac0980fe..772dc01e 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -10,7 +10,7 @@ function graphdef(ex, params = []) return args, body end -function makegraph(graph, args) +function makegraph(graph, args, params = []) graph = prewalk(graph) do v value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ? inputnode(i) : @@ -18,7 +18,8 @@ function makegraph(graph, args) end graph = map(graph) do x x isa Offset ? - :(Flux.Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) : + :(Flux.Offset($(Expr(:quote, x.name)), $(x.n), + $(x.name in params ? :(self.$(x.name)) : x.name))) : x end vertex(:(Flux.Frame(self)), graph) @@ -68,7 +69,7 @@ function process_type(ex) $(build_type(T, params)) $(esc(:((self::$T)($(args...)) = $(build_forward(body, args))))) $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);) - $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args)))) + $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args, params)))) nothing end end