move towards abstract interpreter model
This commit is contained in:
parent
6ac4dd8429
commit
c654fe403a
|
@ -1,15 +1,12 @@
|
|||
import Base: @get!
|
||||
import DataFlow: Constant, postwalk, value, inputs, constant
|
||||
import TensorFlow: RawTensor
|
||||
using Base: @get!
|
||||
using DataFlow: Constant, constant, Context, interpret, interptuple
|
||||
using TensorFlow: RawTensor
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
||||
cvalue(x) = x
|
||||
cvalue(c::Constant) = c.value
|
||||
cvalue(v::Vertex) = cvalue(value(v))
|
||||
|
||||
graph(x::Tensor) = x
|
||||
graph(x::Number) = TensorFlow.constant(Float32(x))
|
||||
node(x::Tensor) = x
|
||||
node(x::Tuple) = map(node, x)
|
||||
node(x::Number) = TensorFlow.constant(Float32(x))
|
||||
|
||||
graph(::typeof(*), args...) = *(args...)
|
||||
graph(::typeof(.*), args...) = .*(args...)
|
||||
|
@ -30,39 +27,32 @@ graph(::Input, x) = x
|
|||
graph(p::MaxPool, x) =
|
||||
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
||||
|
||||
graph(::DataFlow.Group, xs...) = (xs...,)
|
||||
|
||||
graph(params::Associative, c::Conv2D, x) =
|
||||
nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
|
||||
|
||||
graph(op::Op, xs...) = op.f(xs...)
|
||||
Flux.shape(op::Op, d...) = op.shape(d...)
|
||||
|
||||
graph{T<:AArray}(params::Associative, p::Flux.Param{T}) =
|
||||
@get!(params, p, Variable(p.x))
|
||||
interp(ctx, c::Conv2D, x) =
|
||||
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||
|
||||
function graph(params::Associative, v::IVertex, args...)
|
||||
# TODO: check number of arguments
|
||||
v = spliceinputs(v, map(constant, args)...) |> detuple
|
||||
postwalk(v) do v
|
||||
vertex(graph(params, cvalue(v), cvalue.(inputs(v))...))
|
||||
end |> value
|
||||
end
|
||||
interp(ctx, c::Constant) = node(c.value)
|
||||
|
||||
function graph(params::Associative, model, args...)
|
||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||
haskey(ctx[:params], p.value) ?
|
||||
ctx[:params][p.value] :
|
||||
(ctx[:params][p.value] = Variable(p.value.x))
|
||||
|
||||
function interp(ctx, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(model, args...)
|
||||
g == nothing && return graph(model, interpret(ctx, args)...)
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
graph(params, g, args...)
|
||||
interpret(ctx, g, interpret(ctx, args)...)
|
||||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
params = Dict{Flux.Param,Tensor}()
|
||||
g = graph(params, model, args...)
|
||||
return params, g
|
||||
ctx = Context(interptuple(interp), params = ObjectIdDict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], out
|
||||
end
|
||||
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
|
||||
|
||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ function makesession(model::Flux.Unrolled)
|
|||
input = placeholder(Float32)
|
||||
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
||||
instates = [placeholder(Float32) for _ in model.state]
|
||||
params, (outstates, outputs) = tograph(model.graph, cgroup(instates...), cgroup(inputs...))
|
||||
params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
|
||||
output = TensorFlow.pack(outputs, axis = 1)
|
||||
run(sess, initialize_all_variables())
|
||||
sess, params, (instates, input), (outstates, output)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
module TF
|
||||
|
||||
using ..Flux, DataFlow, TensorFlow, Juno
|
||||
import Flux: accuracy, spliceinputs, detuple
|
||||
import Flux: accuracy
|
||||
|
||||
export tf
|
||||
|
||||
|
@ -12,6 +12,8 @@ end
|
|||
|
||||
Op(f) = Op(f, (d...) -> nothing)
|
||||
|
||||
Flux.shape(op::Op, d...) = op.shape(d...)
|
||||
|
||||
include("graph.jl")
|
||||
include("model.jl")
|
||||
include("recurrent.jl")
|
||||
|
|
Loading…
Reference in New Issue