handle state on julia side
This commit is contained in:
parent
e450a585b7
commit
d9ed5676c2
|
@ -1,9 +1,9 @@
|
|||
type Model
|
||||
model
|
||||
model::Any
|
||||
session::Session
|
||||
params::Dict{Flux.Param,Tensor}
|
||||
inputs::Vector{Tensor}
|
||||
output
|
||||
output::Any
|
||||
gradients::Vector{Tensor}
|
||||
end
|
||||
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
immutable SeqModel
|
||||
# TODO: refactor, some of this is more general than just the TF backend
|
||||
|
||||
type SeqModel
|
||||
m::Model
|
||||
state::Any
|
||||
end
|
||||
|
||||
cgroup(xs...) = Flow.group(map(constant, xs)...)
|
||||
|
||||
function tf(model::Flux.Unrolled)
|
||||
sess = Session(Graph())
|
||||
input = placeholder(Float32)
|
||||
instates = [placeholder(Float32) for _ in model.states]
|
||||
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
||||
params, (state, outputs) = tograph(model.graph, inputs...)
|
||||
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
|
||||
output = TensorFlow.pack(outputs, axis = 1)
|
||||
run(sess, initialize_all_variables())
|
||||
Model(model, sess, params,
|
||||
[input], [output],
|
||||
[gradients(output, input)]) |> SeqModel
|
||||
SeqModel(
|
||||
Model(model, sess, params,
|
||||
[instates..., input], [outstates..., output],
|
||||
[gradients(output, input)]),
|
||||
[])
|
||||
end
|
||||
|
||||
function batchseq(xs)
|
||||
|
@ -22,6 +30,13 @@ function batchseq(xs)
|
|||
Batch{Seq{T,S},B}(xs)
|
||||
end
|
||||
|
||||
(m::SeqModel)(x::BatchSeq) = batchseq(rawbatch(m.m(x)))
|
||||
function (m::SeqModel)(x::BatchSeq)
|
||||
if isempty(m.state) || length(first(m.state)) ≠ length(x)
|
||||
m.state = map(batchone, m.m.model.states)
|
||||
end
|
||||
output = m.m(m.state..., x)
|
||||
m.state, output = output[1:end-1], output[end]
|
||||
return batchseq(rawbatch(output))
|
||||
end
|
||||
|
||||
(m::SeqModel)(x::Seq) = first(m(batchone(x)))
|
||||
|
|
|
@ -7,5 +7,6 @@ export tf
|
|||
|
||||
include("graph.jl")
|
||||
include("model.jl")
|
||||
include("recurrent.jl")
|
||||
|
||||
end
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
immutable ModelInput end
|
||||
|
||||
inputnode(n) = vertex(Split(n), constant(ModelInput()))
|
||||
splitnode(v, n) = vertex(Split(n), v)
|
||||
|
||||
inputnode(n) = splitnode(constant(ModelInput()), n)
|
||||
|
||||
function bumpinputs(v::IVertex)
|
||||
prewalk(v) do v
|
||||
|
|
|
@ -63,25 +63,26 @@ end
|
|||
|
||||
function unrollgraph(model, n)
|
||||
graph, defaults = break!(atomise(model))
|
||||
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
|
||||
detuple(outputs[end])
|
||||
outputs = [spliceinputs(graph, group([constant(splitnode(inputnode(1),i)) for i = 1:length(defaults)]...),
|
||||
splitnode(inputnode(2), 1))]
|
||||
for i = 2:n
|
||||
push!(outputs, spliceinputs(graph, outputs[end][1], inputnode(i)))
|
||||
push!(outputs, spliceinputs(graph, outputs[end][1], splitnode(inputnode(2), i)))
|
||||
end
|
||||
state = outputs[end][1]
|
||||
outputs = map(x -> x[2], outputs)
|
||||
@> group(state, group(outputs...)) detuple
|
||||
(@> group(state, group(outputs...)) detuple), map(x->x.x, defaults)
|
||||
end
|
||||
|
||||
type Unrolled <: Model
|
||||
model
|
||||
graph::IVertex{Any}
|
||||
states::Vector{Any}
|
||||
steps::Int
|
||||
end
|
||||
|
||||
graph(u::Unrolled) = u.graph
|
||||
|
||||
unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
|
||||
unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
|
||||
|
||||
@net type Recurrent
|
||||
Wxh; Whh; Why
|
||||
|
@ -95,9 +96,9 @@ end
|
|||
|
||||
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
||||
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
|
||||
initn(hidden), initn(out), zeros(Float32, hidden)')
|
||||
initn(hidden), initn(out), zeros(Float32, hidden))
|
||||
|
||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||
|
||||
# r = Recurrent(10, 30, 20)
|
||||
# unrollgraph(r,5) |> cse |> syntax′ |> prettify |> display
|
||||
# r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10))
|
||||
# unrollgraph(r,5)[1] |> syntax′ |> prettify |> clipboard
|
||||
|
|
Loading…
Reference in New Issue