Merge pull request #34 from MikeInnes/broadcast
Updates for dataflow constants change
This commit is contained in:
commit
8454a7194c
@ -8,7 +8,7 @@ julia:
|
||||
# uncomment the following lines to override the default test script
|
||||
script:
|
||||
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
||||
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
||||
- julia -e 'Pkg.clone(pwd()); Pkg.checkout("DataFlow"); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
||||
after_success:
|
||||
# - julia -e 'Pkg.add("Documenter")'
|
||||
- julia -e 'Pkg.clone("https://github.com/MikeInnes/Documenter.jl")'
|
||||
|
@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
struct AlterParam
|
||||
param
|
||||
load
|
||||
store
|
||||
end
|
||||
|
||||
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
||||
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
||||
|
||||
graph(ctx::Context, d::Affine, x) =
|
||||
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
|
||||
register(ctx,
|
||||
@ -74,19 +83,16 @@ end
|
||||
|
||||
register(ctx::Context, node) = node
|
||||
|
||||
function var(ctx::Context, p)
|
||||
function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam})
|
||||
id = gensym()
|
||||
ctx[:params][id] = p
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
|
||||
|
||||
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
|
||||
|
||||
graph(ctx::Context, p::Constant) = p.value
|
||||
var(ctx::Context, x) = x
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
args = var.(ctx, args)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
@ -96,7 +102,7 @@ end
|
||||
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
||||
|
||||
function tograph(model, args...; feedforward = false)
|
||||
ctx = Context(mux(iline, ilambda, iargs, ituple, graph′),
|
||||
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′),
|
||||
params = Dict(), stacks = Dict(),
|
||||
feedforward = feedforward)
|
||||
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
||||
|
@ -1,14 +1,5 @@
|
||||
using Flux: collectt, shapecheckt
|
||||
|
||||
struct AlterParam
|
||||
param
|
||||
load
|
||||
store
|
||||
end
|
||||
|
||||
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
||||
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
||||
|
||||
function copyargs!(as, bs)
|
||||
for id in intersect(keys(as), keys(bs))
|
||||
copy!(as[id], bs[id])
|
||||
|
@ -1,5 +1,5 @@
|
||||
using Base: @get!
|
||||
using DataFlow: Constant, constant, Split
|
||||
using DataFlow: constant, Split
|
||||
using DataFlow.Interpreter
|
||||
using DataFlow.Interpreter: stack
|
||||
using TensorFlow: RawTensor, TFException
|
||||
@ -53,33 +53,34 @@ graph(p::MaxPool, x) =
|
||||
graph(op::Op, xs...) = op.f(xs...)
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
node = graph(model, interpv(ctx, args)...)
|
||||
node = graph(model, args...)
|
||||
node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
|
||||
return node
|
||||
end
|
||||
|
||||
interp(ctx, c::Conv2D, x) =
|
||||
nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||
nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||
|
||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||
haskey(ctx[:params], p.value) ?
|
||||
ctx[:params][p.value] :
|
||||
(ctx[:params][p.value] =
|
||||
param(ctx, p::Flux.Param{<:AArray}) =
|
||||
haskey(ctx[:params], p) ?
|
||||
ctx[:params][p] :
|
||||
(ctx[:params][p] =
|
||||
ctx[:variables] ?
|
||||
Variable(Float32.(p.value.x)) :
|
||||
Variable(Float32.(p.x)) :
|
||||
placeholder(Float32))
|
||||
|
||||
interp(ctx, p::Constant) = p.value
|
||||
param(ctx, x) = x
|
||||
|
||||
function interp(ctx, model, args...)
|
||||
args = param.(ctx, args)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(ctx, model, args...)
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
interpret(ctx, g, interpv(ctx, args)...)
|
||||
interpret(ctx, g, args...)
|
||||
end
|
||||
|
||||
function tograph(model, args...; variables = false)
|
||||
ctx = Context(mux(iline, ilambda, interp),
|
||||
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp),
|
||||
params = ObjectIdDict(), stacks = Dict(), variables = variables)
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
|
@ -1,4 +1,4 @@
|
||||
import DataFlow: mapconst, cse
|
||||
import DataFlow: cse
|
||||
using MacroTools: @q
|
||||
|
||||
export @net
|
||||
@ -6,13 +6,13 @@ export @net
|
||||
function graphdef(ex, params = [])
|
||||
@capture(shortdef(ex), (args__,) -> body_)
|
||||
body = @> body MacroTools.flatten liftloops graphm DataFlow.il
|
||||
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||
return args, body
|
||||
end
|
||||
|
||||
function makegraph(graph, args, params = [])
|
||||
graph = prewalk(graph) do v
|
||||
value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ?
|
||||
isconstant(v) && (i = findfirst(args, value(v[1]))) ≠ 0 ?
|
||||
inputnode(i) :
|
||||
v
|
||||
end
|
||||
@ -42,7 +42,7 @@ end
|
||||
|
||||
function deref_params(v)
|
||||
map(v) do x
|
||||
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
||||
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
|
||||
end
|
||||
end
|
||||
|
||||
@ -53,7 +53,7 @@ end
|
||||
|
||||
import Lazy: groupby
|
||||
|
||||
reifyparams(v::IVertex) = mapconst(x -> x isa Param ? x.x : x, v)
|
||||
reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v)
|
||||
|
||||
# TODO: type hints for parameters
|
||||
|
||||
@ -69,14 +69,14 @@ function process_type(ex)
|
||||
$(build_type(T, params))
|
||||
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args, params))))
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
|
||||
nothing
|
||||
end
|
||||
end
|
||||
|
||||
function process_anon(ex)
|
||||
args, body = graphdef(ex)
|
||||
:(Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args)[1])))))
|
||||
:(Capacitor($(DataFlow.constructor(map(esc, makegraph(body, args)[1])))))
|
||||
end
|
||||
|
||||
function process_def(ex)
|
||||
|
@ -1,5 +1,5 @@
|
||||
function astuple(xs::Vertex)
|
||||
isconstant(xs) && value(xs).value isa Tuple ? value(xs).value :
|
||||
isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) :
|
||||
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
|
||||
nothing
|
||||
end
|
||||
@ -15,15 +15,13 @@ end
|
||||
|
||||
function interp(ctx, f, xs...)
|
||||
g = graph(f)
|
||||
g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph")
|
||||
@icatch(ctx, g ≠ nothing ?
|
||||
interpret(ctx, reifyparams(g), xs...) :
|
||||
f(xs...))
|
||||
end
|
||||
|
||||
interp(ctx::Context, c::Constant{<:Param}) = c.value.x
|
||||
interp(ctx::Context, c::Constant) = c.value
|
||||
|
||||
function interpmodel(m, args...)
|
||||
ctx = Context(mux(iline, ilambda, iargs, ituple, interp))
|
||||
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
|
||||
@ithrow interp(ctx, m, args...)
|
||||
end
|
||||
|
@ -8,6 +8,10 @@ end
|
||||
|
||||
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
|
||||
|
||||
arghint(p::Param) = arghint(state(p))
|
||||
arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||
arghint(x) = constant(x)
|
||||
|
||||
function gethint(v::IVertex)
|
||||
while value(v) isa Union{Line,Frame} v = v[1] end
|
||||
value(v) isa Hint && return value(v).typ
|
||||
@ -17,19 +21,16 @@ end
|
||||
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
||||
ihint(f, args...) = f(args...)
|
||||
|
||||
hintify(ctx, c::Constant{<:Union{Param,AbstractArray}}) = hintify(ctx, state(c.value))
|
||||
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||
hintify(ctx, c::Constant) = vertex(c)
|
||||
|
||||
interpshape = mux(ilinev, ihint, iargs, hintify)
|
||||
|
||||
function hintify(ctx, f, xs...)
|
||||
sh = infer(f, gethint.(xs)...)
|
||||
xs = arghint.(xs)
|
||||
sh = infer(f, map(gethint, xs)...)
|
||||
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
||||
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
||||
vertex(f, xs...)
|
||||
end
|
||||
|
||||
interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
|
||||
|
||||
function shapesv(f, args...)
|
||||
(g = graph(f)) == nothing && return
|
||||
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
||||
|
Loading…
Reference in New Issue
Block a user