rename ModelInput
This commit is contained in:
parent
545d4480ed
commit
dcdc5fd9c3
|
@ -17,7 +17,7 @@ end
|
|||
function graph(vars, model::Model, args...)
|
||||
g = Flux.graph(model)
|
||||
g = Flow.mapconst(g) do x
|
||||
!isa(x, Flux.Parameter) ? x :
|
||||
!isa(x, Flux.ModelInput) ? x :
|
||||
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
|
||||
end
|
||||
postwalk(g) do v
|
||||
|
|
|
@ -9,15 +9,15 @@ function process_func(ex, params)
|
|||
return args, body
|
||||
end
|
||||
|
||||
immutable Parameter
|
||||
immutable ModelInput
|
||||
name
|
||||
end
|
||||
|
||||
function makegraph(graph, args)
|
||||
@assert length(args) == 1
|
||||
mapconst(graph) do x
|
||||
x == args[1] ? Parameter(1) :
|
||||
@capture(x, self.p_) ? Parameter(p) :
|
||||
x == args[1] ? ModelInput(1) :
|
||||
@capture(x, self.p_) ? ModelInput(p) :
|
||||
x
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue