use param object rather than named input
This commit is contained in:
parent
ee0c5ae14e
commit
dea85df8b7
|
@ -20,8 +20,7 @@ function graph(model::Model, args...)
|
|||
g = Flux.graph(model)
|
||||
g ≠ nothing || error("No graph for $model")
|
||||
g = Flow.mapconst(g) do x
|
||||
!isa(x, Flux.ModelInput) ? x :
|
||||
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
|
||||
isa(x, Flux.ModelInput) ? args[x.n] : x
|
||||
end
|
||||
postwalk(g) do v
|
||||
vertex(graph(cvalue(v), cvalue.(inputs(v))...))
|
||||
|
|
|
@ -14,9 +14,7 @@ end
|
|||
function makegraph(graph, args)
|
||||
@assert length(args) == 1
|
||||
mapconst(graph) do x
|
||||
x == args[1] ? ModelInput(1) :
|
||||
@capture(x, self.p_) ? ModelInput(p) :
|
||||
x
|
||||
x == args[1] ? ModelInput(1) : x
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -74,7 +72,7 @@ function process_type(ex)
|
|||
(self::$T)($(args...),) = $(build_forward(body, args))
|
||||
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||
graph(self::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||
nothing
|
||||
end |> esc
|
||||
end
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# TODO: change the input approach
|
||||
immutable ModelInput
|
||||
name
|
||||
n::Int
|
||||
end
|
||||
|
||||
isinput(x) = isa(x, Constant) && isa(x.value, ModelInput) && isa(x.value.name, Integer)
|
||||
isinput(x) = isa(x, Constant) && isa(x.value, ModelInput)
|
||||
|
||||
bumpinput(i::ModelInput) = isa(i.name, Integer) ? ModelInput(i.name + 1) : i
|
||||
bumpinput(i::ModelInput) = ModelInput(i.n + 1)
|
||||
bumpinput(x) = x
|
||||
|
||||
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
||||
|
@ -13,7 +13,7 @@ bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
|||
function spliceinputs(v::IVertex, inputs::IVertex...)
|
||||
postwalk(v) do v
|
||||
isinput(value(v)) ?
|
||||
inputs[value(v).value.name] :
|
||||
inputs[value(v).value.n] :
|
||||
v
|
||||
end
|
||||
end
|
||||
|
|
|
@ -36,6 +36,10 @@ accumulate!(x, Δ) = x
|
|||
|
||||
@forward Param.x Base.size
|
||||
|
||||
function Base.show(io::IO, p::Param)
|
||||
print(io, "Param", size(p.x))
|
||||
end
|
||||
|
||||
# Anonymous models
|
||||
|
||||
export Capacitor
|
||||
|
|
Loading…
Reference in New Issue