handle state on julia side

This commit is contained in:
Mike J Innes 2016-10-28 21:17:48 +01:00
parent e450a585b7
commit d9ed5676c2
5 changed files with 36 additions and 17 deletions

View File

@ -1,9 +1,9 @@
type Model type Model
model model::Any
session::Session session::Session
params::Dict{Flux.Param,Tensor} params::Dict{Flux.Param,Tensor}
inputs::Vector{Tensor} inputs::Vector{Tensor}
output output::Any
gradients::Vector{Tensor} gradients::Vector{Tensor}
end end

View File

@ -1,17 +1,25 @@
immutable SeqModel # TODO: refactor, some of this is more general than just the TF backend
type SeqModel
m::Model m::Model
state::Any
end end
cgroup(xs...) = Flow.group(map(constant, xs)...)
function tf(model::Flux.Unrolled) function tf(model::Flux.Unrolled)
sess = Session(Graph()) sess = Session(Graph())
input = placeholder(Float32) input = placeholder(Float32)
instates = [placeholder(Float32) for _ in model.states]
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1) inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
params, (state, outputs) = tograph(model.graph, inputs...) params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
output = TensorFlow.pack(outputs, axis = 1) output = TensorFlow.pack(outputs, axis = 1)
run(sess, initialize_all_variables()) run(sess, initialize_all_variables())
Model(model, sess, params, SeqModel(
[input], [output], Model(model, sess, params,
[gradients(output, input)]) |> SeqModel [instates..., input], [outstates..., output],
[gradients(output, input)]),
[])
end end
function batchseq(xs) function batchseq(xs)
@ -22,6 +30,13 @@ function batchseq(xs)
Batch{Seq{T,S},B}(xs) Batch{Seq{T,S},B}(xs)
end end
(m::SeqModel)(x::BatchSeq) = batchseq(rawbatch(m.m(x))) function (m::SeqModel)(x::BatchSeq)
if isempty(m.state) || length(first(m.state)) length(x)
m.state = map(batchone, m.m.model.states)
end
output = m.m(m.state..., x)
m.state, output = output[1:end-1], output[end]
return batchseq(rawbatch(output))
end
(m::SeqModel)(x::Seq) = first(m(batchone(x))) (m::SeqModel)(x::Seq) = first(m(batchone(x)))

View File

@ -7,5 +7,6 @@ export tf
include("graph.jl") include("graph.jl")
include("model.jl") include("model.jl")
include("recurrent.jl")
end end

View File

@ -1,6 +1,8 @@
immutable ModelInput end immutable ModelInput end
inputnode(n) = vertex(Split(n), constant(ModelInput())) splitnode(v, n) = vertex(Split(n), v)
inputnode(n) = splitnode(constant(ModelInput()), n)
function bumpinputs(v::IVertex) function bumpinputs(v::IVertex)
prewalk(v) do v prewalk(v) do v

View File

@ -63,25 +63,26 @@ end
function unrollgraph(model, n) function unrollgraph(model, n)
graph, defaults = break!(atomise(model)) graph, defaults = break!(atomise(model))
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))] outputs = [spliceinputs(graph, group([constant(splitnode(inputnode(1),i)) for i = 1:length(defaults)]...),
detuple(outputs[end]) splitnode(inputnode(2), 1))]
for i = 2:n for i = 2:n
push!(outputs, spliceinputs(graph, outputs[end][1], inputnode(i))) push!(outputs, spliceinputs(graph, outputs[end][1], splitnode(inputnode(2), i)))
end end
state = outputs[end][1] state = outputs[end][1]
outputs = map(x -> x[2], outputs) outputs = map(x -> x[2], outputs)
@> group(state, group(outputs...)) detuple (@> group(state, group(outputs...)) detuple), map(x->x.x, defaults)
end end
type Unrolled <: Model type Unrolled <: Model
model model
graph::IVertex{Any} graph::IVertex{Any}
states::Vector{Any}
steps::Int steps::Int
end end
graph(u::Unrolled) = u.graph graph(u::Unrolled) = u.graph
unroll(model, n) = Unrolled(model, unrollgraph(model, n), n) unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
@net type Recurrent @net type Recurrent
Wxh; Whh; Why Wxh; Whh; Why
@ -95,9 +96,9 @@ end
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) = Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)), Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
initn(hidden), initn(out), zeros(Float32, hidden)') initn(hidden), initn(out), zeros(Float32, hidden))
# syntax(x) = syntax(Flow.dl(x), bindconst = true) # syntax(x) = syntax(Flow.dl(x), bindconst = true)
# r = Recurrent(10, 30, 20) # r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10))
# unrollgraph(r,5) |> cse |> syntax |> prettify |> display # unrollgraph(r,5)[1] |> syntax |> prettify |> clipboard