rename ModelInput
This commit is contained in:
parent
545d4480ed
commit
dcdc5fd9c3
@ -17,7 +17,7 @@ end
|
|||||||
function graph(vars, model::Model, args...)
|
function graph(vars, model::Model, args...)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g = Flow.mapconst(g) do x
|
g = Flow.mapconst(g) do x
|
||||||
!isa(x, Flux.Parameter) ? x :
|
!isa(x, Flux.ModelInput) ? x :
|
||||||
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
|
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
|
||||||
end
|
end
|
||||||
postwalk(g) do v
|
postwalk(g) do v
|
||||||
|
@ -9,15 +9,15 @@ function process_func(ex, params)
|
|||||||
return args, body
|
return args, body
|
||||||
end
|
end
|
||||||
|
|
||||||
immutable Parameter
|
immutable ModelInput
|
||||||
name
|
name
|
||||||
end
|
end
|
||||||
|
|
||||||
function makegraph(graph, args)
|
function makegraph(graph, args)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
mapconst(graph) do x
|
mapconst(graph) do x
|
||||||
x == args[1] ? Parameter(1) :
|
x == args[1] ? ModelInput(1) :
|
||||||
@capture(x, self.p_) ? Parameter(p) :
|
@capture(x, self.p_) ? ModelInput(p) :
|
||||||
x
|
x
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user