Merge pull request #34 from MikeInnes/broadcast

Updates for dataflow constants change
This commit is contained in:
Mike J Innes 2017-06-01 17:37:36 +01:00 committed by GitHub
commit 8454a7194c
7 changed files with 44 additions and 47 deletions

View File

@ -8,7 +8,7 @@ julia:
# uncomment the following lines to override the default test script
script:
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
- julia -e 'Pkg.clone(pwd()); Pkg.checkout("DataFlow"); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
after_success:
# - julia -e 'Pkg.add("Documenter")'
- julia -e 'Pkg.clone("https://github.com/MikeInnes/Documenter.jl")'

View File

@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
graph(::Input, x) = x
struct AlterParam
param
load
store
end
Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
graph(ctx::Context, d::Affine, x) =
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
register(ctx,
@ -74,19 +83,16 @@ end
register(ctx::Context, node) = node
function var(ctx::Context, p)
function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam})
id = gensym()
ctx[:params][id] = p
return mx.Variable(id)
end
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
graph(ctx::Context, p::Constant) = p.value
var(ctx::Context, x) = x
function graph(ctx::Context, model, args...)
args = var.(ctx, args)
g = Flux.graph(model)
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
@ -96,7 +102,7 @@ end
graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
function tograph(model, args...; feedforward = false)
ctx = Context(mux(iline, ilambda, iargs, ituple, graph),
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph),
params = Dict(), stacks = Dict(),
feedforward = feedforward)
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)

View File

@ -1,14 +1,5 @@
using Flux: collectt, shapecheckt
struct AlterParam
param
load
store
end
Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
function copyargs!(as, bs)
for id in intersect(keys(as), keys(bs))
copy!(as[id], bs[id])

View File

@ -1,5 +1,5 @@
using Base: @get!
using DataFlow: Constant, constant, Split
using DataFlow: constant, Split
using DataFlow.Interpreter
using DataFlow.Interpreter: stack
using TensorFlow: RawTensor, TFException
@ -53,33 +53,34 @@ graph(p::MaxPool, x) =
graph(op::Op, xs...) = op.f(xs...)
function graph(ctx::Context, model, args...)
node = graph(model, interpv(ctx, args)...)
node = graph(model, args...)
node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
return node
end
interp(ctx, c::Conv2D, x) =
nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID")
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
haskey(ctx[:params], p.value) ?
ctx[:params][p.value] :
(ctx[:params][p.value] =
param(ctx, p::Flux.Param{<:AArray}) =
haskey(ctx[:params], p) ?
ctx[:params][p] :
(ctx[:params][p] =
ctx[:variables] ?
Variable(Float32.(p.value.x)) :
Variable(Float32.(p.x)) :
placeholder(Float32))
interp(ctx, p::Constant) = p.value
param(ctx, x) = x
function interp(ctx, model, args...)
args = param.(ctx, args)
g = Flux.graph(model)
g == nothing && return graph(ctx, model, args...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, interpv(ctx, args)...)
interpret(ctx, g, args...)
end
function tograph(model, args...; variables = false)
ctx = Context(mux(iline, ilambda, interp),
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp),
params = ObjectIdDict(), stacks = Dict(), variables = variables)
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], ctx[:stacks], out

View File

@ -1,4 +1,4 @@
import DataFlow: mapconst, cse
import DataFlow: cse
using MacroTools: @q
export @net
@ -6,13 +6,13 @@ export @net
function graphdef(ex, params = [])
@capture(shortdef(ex), (args__,) -> body_)
body = @> body MacroTools.flatten liftloops graphm DataFlow.il
body = mapconst(x -> x in params ? :(self.$x) : x, body)
body = map(x -> x in params ? :(self.$x) : x, body)
return args, body
end
function makegraph(graph, args, params = [])
graph = prewalk(graph) do v
value(v) isa Constant && (i = findfirst(args, value(v).value)) 0 ?
isconstant(v) && (i = findfirst(args, value(v[1]))) 0 ?
inputnode(i) :
v
end
@ -42,7 +42,7 @@ end
function deref_params(v)
map(v) do x
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
end
end
@ -53,7 +53,7 @@ end
import Lazy: groupby
reifyparams(v::IVertex) = mapconst(x -> x isa Param ? x.x : x, v)
reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v)
# TODO: type hints for parameters
@ -69,14 +69,14 @@ function process_type(ex)
$(build_type(T, params))
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args, params))))
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
nothing
end
end
function process_anon(ex)
args, body = graphdef(ex)
:(Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args)[1])))))
:(Capacitor($(DataFlow.constructor(map(esc, makegraph(body, args)[1])))))
end
function process_def(ex)

View File

@ -1,5 +1,5 @@
function astuple(xs::Vertex)
isconstant(xs) && value(xs).value isa Tuple ? value(xs).value :
isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) :
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
nothing
end
@ -15,15 +15,13 @@ end
function interp(ctx, f, xs...)
g = graph(f)
g nothing && iscyclic(g) && error("Can't interpret cyclic graph")
@icatch(ctx, g nothing ?
interpret(ctx, reifyparams(g), xs...) :
f(xs...))
end
interp(ctx::Context, c::Constant{<:Param}) = c.value.x
interp(ctx::Context, c::Constant) = c.value
function interpmodel(m, args...)
ctx = Context(mux(iline, ilambda, iargs, ituple, interp))
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
@ithrow interp(ctx, m, args...)
end

View File

@ -8,6 +8,10 @@ end
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
arghint(p::Param) = arghint(state(p))
arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
arghint(x) = constant(x)
function gethint(v::IVertex)
while value(v) isa Union{Line,Frame} v = v[1] end
value(v) isa Hint && return value(v).typ
@ -17,19 +21,16 @@ end
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
ihint(f, args...) = f(args...)
hintify(ctx, c::Constant{<:Union{Param,AbstractArray}}) = hintify(ctx, state(c.value))
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
hintify(ctx, c::Constant) = vertex(c)
interpshape = mux(ilinev, ihint, iargs, hintify)
function hintify(ctx, f, xs...)
sh = infer(f, gethint.(xs)...)
xs = arghint.(xs)
sh = infer(f, map(gethint, xs)...)
sh nothing ? vertex(Hint(sh), vertex(f, xs...)) :
!any(x->x==nothing, xs) && graph(f) nothing ? interpret(Context(interpshape), graph(f), xs...) :
vertex(f, xs...)
end
interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
function shapesv(f, args...)
(g = graph(f)) == nothing && return
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]