use param object rather than named input

This commit is contained in:
Mike J Innes 2016-10-25 17:57:20 +01:00
parent ee0c5ae14e
commit dea85df8b7
4 changed files with 11 additions and 10 deletions

View File

@ -20,8 +20,7 @@ function graph(model::Model, args...)
g = Flux.graph(model)
g nothing || error("No graph for $model")
g = Flow.mapconst(g) do x
!isa(x, Flux.ModelInput) ? x :
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
isa(x, Flux.ModelInput) ? args[x.n] : x
end
postwalk(g) do v
vertex(graph(cvalue(v), cvalue.(inputs(v))...))

View File

@ -14,9 +14,7 @@ end
function makegraph(graph, args)
@assert length(args) == 1
mapconst(graph) do x
x == args[1] ? ModelInput(1) :
@capture(x, self.p_) ? ModelInput(p) :
x
x == args[1] ? ModelInput(1) : x
end
end
@ -74,7 +72,7 @@ function process_type(ex)
(self::$T)($(args...),) = $(build_forward(body, args))
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
graph(::$T) = $(Flow.constructor(makegraph(body, args)))
graph(self::$T) = $(Flow.constructor(makegraph(body, args)))
nothing
end |> esc
end

View File

@ -1,11 +1,11 @@
# TODO: change the input approach
immutable ModelInput
name
n::Int
end
isinput(x) = isa(x, Constant) && isa(x.value, ModelInput) && isa(x.value.name, Integer)
isinput(x) = isa(x, Constant) && isa(x.value, ModelInput)
bumpinput(i::ModelInput) = isa(i.name, Integer) ? ModelInput(i.name + 1) : i
bumpinput(i::ModelInput) = ModelInput(i.n + 1)
bumpinput(x) = x
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
@ -13,7 +13,7 @@ bumpinputs(v::IVertex) = mapconst(bumpinput, v)
function spliceinputs(v::IVertex, inputs::IVertex...)
postwalk(v) do v
isinput(value(v)) ?
inputs[value(v).value.name] :
inputs[value(v).value.n] :
v
end
end

View File

@ -36,6 +36,10 @@ accumulate!(x, Δ) = x
@forward Param.x Base.size
function Base.show(io::IO, p::Param)
print(io, "Param", size(p.x))
end
# Anonymous models
export Capacitor