move towards abstract interpreter model

This commit is contained in:
Mike J Innes 2016-11-13 20:27:20 +00:00
parent 6ac4dd8429
commit c654fe403a
3 changed files with 25 additions and 33 deletions

View File

@ -1,15 +1,12 @@
import Base: @get! using Base: @get!
import DataFlow: Constant, postwalk, value, inputs, constant using DataFlow: Constant, constant, Context, interpret, interptuple
import TensorFlow: RawTensor using TensorFlow: RawTensor
# TODO: implement Julia's type promotion rules # TODO: implement Julia's type promotion rules
cvalue(x) = x node(x::Tensor) = x
cvalue(c::Constant) = c.value node(x::Tuple) = map(node, x)
cvalue(v::Vertex) = cvalue(value(v)) node(x::Number) = TensorFlow.constant(Float32(x))
graph(x::Tensor) = x
graph(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(*), args...) = *(args...) graph(::typeof(*), args...) = *(args...)
graph(::typeof(.*), args...) = .*(args...) graph(::typeof(.*), args...) = .*(args...)
@ -30,39 +27,32 @@ graph(::Input, x) = x
graph(p::MaxPool, x) = graph(p::MaxPool, x) =
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID") nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
graph(::DataFlow.Group, xs...) = (xs...,)
graph(params::Associative, c::Conv2D, x) =
nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
graph(op::Op, xs...) = op.f(xs...) graph(op::Op, xs...) = op.f(xs...)
Flux.shape(op::Op, d...) = op.shape(d...)
graph{T<:AArray}(params::Associative, p::Flux.Param{T}) = interp(ctx, c::Conv2D, x) =
@get!(params, p, Variable(p.x)) nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
function graph(params::Associative, v::IVertex, args...) interp(ctx, c::Constant) = node(c.value)
# TODO: check number of arguments
v = spliceinputs(v, map(constant, args)...) |> detuple
postwalk(v) do v
vertex(graph(params, cvalue(v), cvalue.(inputs(v))...))
end |> value
end
function graph(params::Associative, model, args...) interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
haskey(ctx[:params], p.value) ?
ctx[:params][p.value] :
(ctx[:params][p.value] = Variable(p.value.x))
function interp(ctx, model, args...)
g = Flux.graph(model) g = Flux.graph(model)
g == nothing && return graph(model, args...) g == nothing && return graph(model, interpret(ctx, args)...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
graph(params, g, args...) interpret(ctx, g, interpret(ctx, args)...)
end end
function tograph(model, args...) function tograph(model, args...)
params = Dict{Flux.Param,Tensor}() ctx = Context(interptuple(interp), params = ObjectIdDict())
g = graph(params, model, args...) out = interp(ctx, model, map(constant, args)...)
return params, g return ctx[:params], out
end end
TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...) TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data)) RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

View File

@ -12,7 +12,7 @@ function makesession(model::Flux.Unrolled)
input = placeholder(Float32) input = placeholder(Float32)
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1) inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
instates = [placeholder(Float32) for _ in model.state] instates = [placeholder(Float32) for _ in model.state]
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...)) params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
output = TensorFlow.pack(outputs, axis = 1) output = TensorFlow.pack(outputs, axis = 1)
run(sess, initialize_all_variables()) run(sess, initialize_all_variables())
sess, params, (instates, input), (outstates, output) sess, params, (instates, input), (outstates, output)

View File

@ -1,7 +1,7 @@
module TF module TF
using ..Flux, DataFlow, TensorFlow, Juno using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy, spliceinputs, detuple import Flux: accuracy
export tf export tf
@ -12,6 +12,8 @@ end
Op(f) = Op(f, (d...) -> nothing) Op(f) = Op(f, (d...) -> nothing)
Flux.shape(op::Op, d...) = op.shape(d...)
include("graph.jl") include("graph.jl")
include("model.jl") include("model.jl")
include("recurrent.jl") include("recurrent.jl")