use param object rather than named input
This commit is contained in:
parent
ee0c5ae14e
commit
dea85df8b7
@ -20,8 +20,7 @@ function graph(model::Model, args...)
|
|||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g ≠ nothing || error("No graph for $model")
|
g ≠ nothing || error("No graph for $model")
|
||||||
g = Flow.mapconst(g) do x
|
g = Flow.mapconst(g) do x
|
||||||
!isa(x, Flux.ModelInput) ? x :
|
isa(x, Flux.ModelInput) ? args[x.n] : x
|
||||||
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
|
|
||||||
end
|
end
|
||||||
postwalk(g) do v
|
postwalk(g) do v
|
||||||
vertex(graph(cvalue(v), cvalue.(inputs(v))...))
|
vertex(graph(cvalue(v), cvalue.(inputs(v))...))
|
||||||
|
@ -14,9 +14,7 @@ end
|
|||||||
function makegraph(graph, args)
|
function makegraph(graph, args)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
mapconst(graph) do x
|
mapconst(graph) do x
|
||||||
x == args[1] ? ModelInput(1) :
|
x == args[1] ? ModelInput(1) : x
|
||||||
@capture(x, self.p_) ? ModelInput(p) :
|
|
||||||
x
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -74,7 +72,7 @@ function process_type(ex)
|
|||||||
(self::$T)($(args...),) = $(build_forward(body, args))
|
(self::$T)($(args...),) = $(build_forward(body, args))
|
||||||
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
||||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||||
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
|
graph(self::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||||
nothing
|
nothing
|
||||||
end |> esc
|
end |> esc
|
||||||
end
|
end
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
# TODO: change the input approach
|
# TODO: change the input approach
|
||||||
immutable ModelInput
|
immutable ModelInput
|
||||||
name
|
n::Int
|
||||||
end
|
end
|
||||||
|
|
||||||
isinput(x) = isa(x, Constant) && isa(x.value, ModelInput) && isa(x.value.name, Integer)
|
isinput(x) = isa(x, Constant) && isa(x.value, ModelInput)
|
||||||
|
|
||||||
bumpinput(i::ModelInput) = isa(i.name, Integer) ? ModelInput(i.name + 1) : i
|
bumpinput(i::ModelInput) = ModelInput(i.n + 1)
|
||||||
bumpinput(x) = x
|
bumpinput(x) = x
|
||||||
|
|
||||||
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
||||||
@ -13,7 +13,7 @@ bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
|||||||
function spliceinputs(v::IVertex, inputs::IVertex...)
|
function spliceinputs(v::IVertex, inputs::IVertex...)
|
||||||
postwalk(v) do v
|
postwalk(v) do v
|
||||||
isinput(value(v)) ?
|
isinput(value(v)) ?
|
||||||
inputs[value(v).value.name] :
|
inputs[value(v).value.n] :
|
||||||
v
|
v
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -36,6 +36,10 @@ accumulate!(x, Δ) = x
|
|||||||
|
|
||||||
@forward Param.x Base.size
|
@forward Param.x Base.size
|
||||||
|
|
||||||
|
function Base.show(io::IO, p::Param)
|
||||||
|
print(io, "Param", size(p.x))
|
||||||
|
end
|
||||||
|
|
||||||
# Anonymous models
|
# Anonymous models
|
||||||
|
|
||||||
export Capacitor
|
export Capacitor
|
||||||
|
Loading…
Reference in New Issue
Block a user