handle state on julia side
This commit is contained in:
parent
e450a585b7
commit
d9ed5676c2
@ -1,9 +1,9 @@
|
|||||||
type Model
|
# Holds a model together with a `Session`, a parameter-to-`Tensor` dictionary,
# a vector of input tensors, an output, and a vector of gradient tensors.
type Model
|
||||||
model
|
# The wrapped model; explicitly annotated as `Any`.
model::Any
|
||||||
session::Session
|
session::Session
|
||||||
params::Dict{Flux.Param,Tensor}
|
# Maps each `Flux.Param` to a `Tensor`.
params::Dict{Flux.Param,Tensor}
|
||||||
inputs::Vector{Tensor}
|
inputs::Vector{Tensor}
|
||||||
output
|
# Explicitly annotated as `Any`, so any value may be stored here.
output::Any
|
||||||
gradients::Vector{Tensor}
|
gradients::Vector{Tensor}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,17 +1,25 @@
|
|||||||
immutable SeqModel
|
# TODO: refactor, some of this is more general than just the TF backend
|
||||||
|
|
||||||
|
# Wraps a `Model` together with a `state` field. Declared with `type`
# (mutable) so that `state` can be reassigned after construction.
type SeqModel
|
||||||
m::Model
|
m::Model
|
||||||
|
# Untyped storage for the state carried between calls.
state::Any
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Wrap every argument with `constant` and combine the results with `Flow.group`.
function cgroup(xs...)
    wrapped = map(constant, xs)
    return Flow.group(wrapped...)
end
|
||||||
|
|
||||||
function tf(model::Flux.Unrolled)
|
# Build a `SeqModel` for an unrolled Flux model: create a session, make
# placeholders for the states and the input, convert `model.graph` with
# `tograph`, initialise variables, and wrap everything in `Model`/`SeqModel`.
function tf(model::Flux.Unrolled)
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
input = placeholder(Float32)
|
input = placeholder(Float32)
|
||||||
|
# One Float32 placeholder per entry of `model.states`.
instates = [placeholder(Float32) for _ in model.states]
|
||||||
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
# Split `input` into `model.steps` pieces along axis 1.
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
||||||
params, (state, outputs) = tograph(model.graph, inputs...)
|
# The state placeholders and the per-step inputs are each passed as one grouped argument.
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
|
||||||
output = TensorFlow.pack(outputs, axis = 1)
|
# Join the per-step outputs back together along axis 1.
output = TensorFlow.pack(outputs, axis = 1)
|
||||||
run(sess, initialize_all_variables())
|
run(sess, initialize_all_variables())
|
||||||
Model(model, sess, params,
|
SeqModel(
|
||||||
[input], [output],
|
Model(model, sess, params,
|
||||||
[gradients(output, input)]) |> SeqModel
|
# Model inputs are the state placeholders followed by `input`; the output
# slot holds the output states followed by `output`.
[instates..., input], [outstates..., output],
|
||||||
|
# NOTE(review): gradients are taken with respect to `input` only, not
# `instates` — confirm this is intended.
[gradients(output, input)]),
|
||||||
|
# The `SeqModel` starts with an empty `state`.
[])
|
||||||
end
|
end
|
||||||
|
|
||||||
function batchseq(xs)
|
function batchseq(xs)
|
||||||
@ -22,6 +30,13 @@ function batchseq(xs)
|
|||||||
Batch{Seq{T,S},B}(xs)
|
Batch{Seq{T,S},B}(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::SeqModel)(x::BatchSeq) = batchseq(rawbatch(m.m(x)))
|
# Run the wrapped model on `x`, passing in the stored state and saving the
# state it returns back into `m.state`.
function (m::SeqModel)(x::BatchSeq)
|
||||||
|
# (Re)initialise the state when none is stored yet, or when the length of the
# first stored state differs from `length(x)`.
if isempty(m.state) || length(first(m.state)) ≠ length(x)
|
||||||
|
m.state = map(batchone, m.m.model.states)
|
||||||
|
end
|
||||||
|
output = m.m(m.state..., x)
|
||||||
|
# Everything but the last result becomes the new state; the last result is
# the model output.
m.state, output = output[1:end-1], output[end]
|
||||||
|
return batchseq(rawbatch(output))
|
||||||
|
end
|
||||||
|
|
||||||
(m::SeqModel)(x::Seq) = first(m(batchone(x)))
|
# Convenience method for a single `Seq`: wrap `x` with `batchone`, call the
# model on the result, and take the first element of what comes back.
(m::SeqModel)(x::Seq) = first(m(batchone(x)))
|
||||||
|
@ -7,5 +7,6 @@ export tf
|
|||||||
|
|
||||||
include("graph.jl")
|
include("graph.jl")
|
||||||
include("model.jl")
|
include("model.jl")
|
||||||
|
include("recurrent.jl")
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
immutable ModelInput end
|
# Field-less marker type; `inputnode` wraps an instance of it in `constant`.
immutable ModelInput end
|
||||||
|
|
||||||
inputnode(n) = vertex(Split(n), constant(ModelInput()))
|
# Create a vertex that applies `Split(n)` to `v`.
function splitnode(v, n)
    return vertex(Split(n), v)
end
|
||||||
|
|
||||||
|
# Apply `splitnode` with index `n` to a constant wrapping a fresh `ModelInput()`.
function inputnode(n)
    marker = constant(ModelInput())
    return splitnode(marker, n)
end
|
||||||
|
|
||||||
function bumpinputs(v::IVertex)
|
function bumpinputs(v::IVertex)
|
||||||
prewalk(v) do v
|
prewalk(v) do v
|
||||||
|
@ -63,25 +63,26 @@ end
|
|||||||
|
|
||||||
function unrollgraph(model, n)
|
# Unroll `model` for `n` steps. Returns a tuple: the unrolled graph and the
# `x` field of each default produced by `break!`.
function unrollgraph(model, n)
|
||||||
graph, defaults = break!(atomise(model))
|
graph, defaults = break!(atomise(model))
|
||||||
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
|
# First step: the state argument is a group of `splitnode(inputnode(1), i)`
# constants, one per default; the input argument is `splitnode(inputnode(2), 1)`.
outputs = [spliceinputs(graph, group([constant(splitnode(inputnode(1),i)) for i = 1:length(defaults)]...),
|
||||||
detuple(outputs[end])
|
splitnode(inputnode(2), 1))]
|
||||||
for i = 2:n
|
for i = 2:n
|
||||||
push!(outputs, spliceinputs(graph, outputs[end][1], inputnode(i)))
|
# Each later step takes the first result of the previous step as its state
# argument and `splitnode(inputnode(2), i)` as its input.
push!(outputs, spliceinputs(graph, outputs[end][1], splitnode(inputnode(2), i)))
|
||||||
end
|
end
|
||||||
state = outputs[end][1]
|
state = outputs[end][1]
|
||||||
outputs = map(x -> x[2], outputs)
|
outputs = map(x -> x[2], outputs)
|
||||||
@> group(state, group(outputs...)) detuple
|
# Final state and per-step outputs are grouped and passed through `detuple`;
# the second tuple element collects `d.x` for every default `d`.
(@> group(state, group(outputs...)) detuple), map(x->x.x, defaults)
|
||||||
end
|
end
|
||||||
|
|
||||||
type Unrolled <: Model
|
# A model together with its unrolled graph, a vector of states, and a step count.
type Unrolled <: Model
|
||||||
model
|
model
|
||||||
graph::IVertex{Any}
|
graph::IVertex{Any}
|
||||||
|
# Filled by `unroll` from the second value returned by `unrollgraph`.
states::Vector{Any}
|
||||||
steps::Int
|
steps::Int
|
||||||
end
|
end
|
||||||
|
|
||||||
graph(u::Unrolled) = u.graph
|
# Accessor: return the `graph` field of an `Unrolled`.
graph(u::Unrolled) = u.graph
|
||||||
|
|
||||||
unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
|
# Unroll `model` for `n` steps and wrap the values returned by `unrollgraph`
# in an `Unrolled`.
function unroll(model, n)
    unrolled = unrollgraph(model, n)
    return Unrolled(model, unrolled..., n)
end
|
||||||
|
|
||||||
@net type Recurrent
|
@net type Recurrent
|
||||||
Wxh; Whh; Why
|
Wxh; Whh; Why
|
||||||
@ -95,9 +96,9 @@ end
|
|||||||
|
|
||||||
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
# Convenience constructor: build a `Recurrent` from the three sizes `in`,
# `hidden` and `out`. The last argument is a zero Float32 vector of length
# `hidden`.
# NOTE(review): the `init` keyword is accepted but never used — the body calls
# `initn` directly. Confirm whether `init` was meant to replace those calls.
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
||||||
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
|
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
|
||||||
initn(hidden), initn(out), zeros(Float32, hidden)')
|
initn(hidden), initn(out), zeros(Float32, hidden))
|
||||||
|
|
||||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||||
|
|
||||||
# r = Recurrent(10, 30, 20)
|
# r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10))
|
||||||
# unrollgraph(r,5) |> cse |> syntax′ |> prettify |> display
|
# unrollgraph(r,5)[1] |> syntax′ |> prettify |> clipboard
|
||||||
|
Loading…
Reference in New Issue
Block a user