handle state on julia side

This commit is contained in:
Mike J Innes 2016-10-28 21:17:48 +01:00
parent e450a585b7
commit d9ed5676c2
5 changed files with 36 additions and 17 deletions

View File

@ -1,9 +1,9 @@
type Model
model
model::Any
session::Session
params::Dict{Flux.Param,Tensor}
inputs::Vector{Tensor}
output
output::Any
gradients::Vector{Tensor}
end

View File

@ -1,17 +1,25 @@
immutable SeqModel
# TODO: refactor, some of this is more general than just the TF backend
type SeqModel
m::Model
state::Any
end
cgroup(xs...) = Flow.group(map(constant, xs)...)
function tf(model::Flux.Unrolled)
sess = Session(Graph())
input = placeholder(Float32)
instates = [placeholder(Float32) for _ in model.states]
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
params, (state, outputs) = tograph(model.graph, inputs...)
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
output = TensorFlow.pack(outputs, axis = 1)
run(sess, initialize_all_variables())
Model(model, sess, params,
[input], [output],
[gradients(output, input)]) |> SeqModel
SeqModel(
Model(model, sess, params,
[instates..., input], [outstates..., output],
[gradients(output, input)]),
[])
end
function batchseq(xs)
@ -22,6 +30,13 @@ function batchseq(xs)
Batch{Seq{T,S},B}(xs)
end
(m::SeqModel)(x::BatchSeq) = batchseq(rawbatch(m.m(x)))
function (m::SeqModel)(x::BatchSeq)
if isempty(m.state) || length(first(m.state)) length(x)
m.state = map(batchone, m.m.model.states)
end
output = m.m(m.state..., x)
m.state, output = output[1:end-1], output[end]
return batchseq(rawbatch(output))
end
(m::SeqModel)(x::Seq) = first(m(batchone(x)))

View File

@ -7,5 +7,6 @@ export tf
include("graph.jl")
include("model.jl")
include("recurrent.jl")
end

View File

@ -1,6 +1,8 @@
immutable ModelInput end
inputnode(n) = vertex(Split(n), constant(ModelInput()))
splitnode(v, n) = vertex(Split(n), v)
inputnode(n) = splitnode(constant(ModelInput()), n)
function bumpinputs(v::IVertex)
prewalk(v) do v

View File

@ -63,25 +63,26 @@ end
function unrollgraph(model, n)
graph, defaults = break!(atomise(model))
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
detuple(outputs[end])
outputs = [spliceinputs(graph, group([constant(splitnode(inputnode(1),i)) for i = 1:length(defaults)]...),
splitnode(inputnode(2), 1))]
for i = 2:n
push!(outputs, spliceinputs(graph, outputs[end][1], inputnode(i)))
push!(outputs, spliceinputs(graph, outputs[end][1], splitnode(inputnode(2), i)))
end
state = outputs[end][1]
outputs = map(x -> x[2], outputs)
@> group(state, group(outputs...)) detuple
(@> group(state, group(outputs...)) detuple), map(x->x.x, defaults)
end
type Unrolled <: Model
model
graph::IVertex{Any}
states::Vector{Any}
steps::Int
end
graph(u::Unrolled) = u.graph
unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
@net type Recurrent
Wxh; Whh; Why
@ -95,9 +96,9 @@ end
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
initn(hidden), initn(out), zeros(Float32, hidden)')
initn(hidden), initn(out), zeros(Float32, hidden))
# syntax(x) = syntax(Flow.dl(x), bindconst = true)
# r = Recurrent(10, 30, 20)
# unrollgraph(r,5) |> cse |> syntax |> prettify |> display
# r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10))
# unrollgraph(r,5)[1] |> syntax |> prettify |> clipboard