fix for recurrent anon models
This commit is contained in:
parent
e37973c3d5
commit
b145b46cbb
@ -10,7 +10,7 @@ function graphdef(ex, params = [])
|
|||||||
return args, body
|
return args, body
|
||||||
end
|
end
|
||||||
|
|
||||||
function makegraph(graph, args)
|
function makegraph(graph, args, params = [])
|
||||||
graph = prewalk(graph) do v
|
graph = prewalk(graph) do v
|
||||||
value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ?
|
value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ?
|
||||||
inputnode(i) :
|
inputnode(i) :
|
||||||
@ -18,7 +18,8 @@ function makegraph(graph, args)
|
|||||||
end
|
end
|
||||||
graph = map(graph) do x
|
graph = map(graph) do x
|
||||||
x isa Offset ?
|
x isa Offset ?
|
||||||
:(Flux.Offset($(Expr(:quote, x.name)), $(x.n), self.$(x.name))) :
|
:(Flux.Offset($(Expr(:quote, x.name)), $(x.n),
|
||||||
|
$(x.name in params ? :(self.$(x.name)) : x.name))) :
|
||||||
x
|
x
|
||||||
end
|
end
|
||||||
vertex(:(Flux.Frame(self)), graph)
|
vertex(:(Flux.Frame(self)), graph)
|
||||||
@ -68,7 +69,7 @@ function process_type(ex)
|
|||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
||||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args, params))))
|
||||||
nothing
|
nothing
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user