move towards abstract interpreter model

This commit is contained in:
Mike J Innes 2016-11-13 20:27:20 +00:00
parent 6ac4dd8429
commit c654fe403a
3 changed files with 25 additions and 33 deletions

View File

@ -1,15 +1,12 @@
import Base: @get!
import DataFlow: Constant, postwalk, value, inputs, constant
import TensorFlow: RawTensor
using Base: @get!
using DataFlow: Constant, constant, Context, interpret, interptuple
using TensorFlow: RawTensor
# TODO: implement Julia's type promotion rules
cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))
graph(x::Tensor) = x
graph(x::Number) = TensorFlow.constant(Float32(x))
node(x::Tensor) = x
node(x::Tuple) = map(node, x)
node(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(*), args...) = *(args...)
graph(::typeof(.*), args...) = .*(args...)
@ -30,39 +27,32 @@ graph(::Input, x) = x
graph(p::MaxPool, x) =
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
graph(::DataFlow.Group, xs...) = (xs...,)
graph(params::Associative, c::Conv2D, x) =
nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
graph(op::Op, xs...) = op.f(xs...)
Flux.shape(op::Op, d...) = op.shape(d...)
graph{T<:AArray}(params::Associative, p::Flux.Param{T}) =
@get!(params, p, Variable(p.x))
interp(ctx, c::Conv2D, x) =
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
function graph(params::Associative, v::IVertex, args...)
# TODO: check number of arguments
v = spliceinputs(v, map(constant, args)...) |> detuple
postwalk(v) do v
vertex(graph(params, cvalue(v), cvalue.(inputs(v))...))
end |> value
end
interp(ctx, c::Constant) = node(c.value)
function graph(params::Associative, model, args...)
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
haskey(ctx[:params], p.value) ?
ctx[:params][p.value] :
(ctx[:params][p.value] = Variable(p.value.x))
function interp(ctx, model, args...)
g = Flux.graph(model)
g == nothing && return graph(model, args...)
g == nothing && return graph(model, interpret(ctx, args)...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
graph(params, g, args...)
interpret(ctx, g, interpret(ctx, args)...)
end
function tograph(model, args...)
params = Dict{Flux.Param,Tensor}()
g = graph(params, model, args...)
return params, g
ctx = Context(interptuple(interp), params = ObjectIdDict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out
end
TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

View File

@ -12,7 +12,7 @@ function makesession(model::Flux.Unrolled)
input = placeholder(Float32)
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
instates = [placeholder(Float32) for _ in model.state]
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
output = TensorFlow.pack(outputs, axis = 1)
run(sess, initialize_all_variables())
sess, params, (instates, input), (outstates, output)

View File

@ -1,7 +1,7 @@
module TF
using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy, spliceinputs, detuple
import Flux: accuracy
export tf
@ -12,6 +12,8 @@ end
Op(f) = Op(f, (d...) -> nothing)
Flux.shape(op::Op, d...) = op.shape(d...)
include("graph.jl")
include("model.jl")
include("recurrent.jl")