Merge pull request #34 from MikeInnes/broadcast
Updates for dataflow constants change
This commit is contained in:
commit
8454a7194c
@ -8,7 +8,7 @@ julia:
|
|||||||
# uncomment the following lines to override the default test script
|
# uncomment the following lines to override the default test script
|
||||||
script:
|
script:
|
||||||
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
||||||
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
- julia -e 'Pkg.clone(pwd()); Pkg.checkout("DataFlow"); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
||||||
after_success:
|
after_success:
|
||||||
# - julia -e 'Pkg.add("Documenter")'
|
# - julia -e 'Pkg.add("Documenter")'
|
||||||
- julia -e 'Pkg.clone("https://github.com/MikeInnes/Documenter.jl")'
|
- julia -e 'Pkg.clone("https://github.com/MikeInnes/Documenter.jl")'
|
||||||
|
@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
|||||||
|
|
||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
|
struct AlterParam
|
||||||
|
param
|
||||||
|
load
|
||||||
|
store
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
||||||
|
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
||||||
|
|
||||||
graph(ctx::Context, d::Affine, x) =
|
graph(ctx::Context, d::Affine, x) =
|
||||||
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
|
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
|
||||||
register(ctx,
|
register(ctx,
|
||||||
@ -74,19 +83,16 @@ end
|
|||||||
|
|
||||||
register(ctx::Context, node) = node
|
register(ctx::Context, node) = node
|
||||||
|
|
||||||
function var(ctx::Context, p)
|
function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam})
|
||||||
id = gensym()
|
id = gensym()
|
||||||
ctx[:params][id] = p
|
ctx[:params][id] = p
|
||||||
return mx.Variable(id)
|
return mx.Variable(id)
|
||||||
end
|
end
|
||||||
|
|
||||||
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
|
var(ctx::Context, x) = x
|
||||||
|
|
||||||
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
|
|
||||||
|
|
||||||
graph(ctx::Context, p::Constant) = p.value
|
|
||||||
|
|
||||||
function graph(ctx::Context, model, args...)
|
function graph(ctx::Context, model, args...)
|
||||||
|
args = var.(ctx, args)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
|
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
|
||||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||||
@ -96,7 +102,7 @@ end
|
|||||||
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
||||||
|
|
||||||
function tograph(model, args...; feedforward = false)
|
function tograph(model, args...; feedforward = false)
|
||||||
ctx = Context(mux(iline, ilambda, iargs, ituple, graph′),
|
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′),
|
||||||
params = Dict(), stacks = Dict(),
|
params = Dict(), stacks = Dict(),
|
||||||
feedforward = feedforward)
|
feedforward = feedforward)
|
||||||
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
||||||
|
@ -1,14 +1,5 @@
|
|||||||
using Flux: collectt, shapecheckt
|
using Flux: collectt, shapecheckt
|
||||||
|
|
||||||
struct AlterParam
|
|
||||||
param
|
|
||||||
load
|
|
||||||
store
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
|
||||||
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
|
||||||
|
|
||||||
function copyargs!(as, bs)
|
function copyargs!(as, bs)
|
||||||
for id in intersect(keys(as), keys(bs))
|
for id in intersect(keys(as), keys(bs))
|
||||||
copy!(as[id], bs[id])
|
copy!(as[id], bs[id])
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Base: @get!
|
using Base: @get!
|
||||||
using DataFlow: Constant, constant, Split
|
using DataFlow: constant, Split
|
||||||
using DataFlow.Interpreter
|
using DataFlow.Interpreter
|
||||||
using DataFlow.Interpreter: stack
|
using DataFlow.Interpreter: stack
|
||||||
using TensorFlow: RawTensor, TFException
|
using TensorFlow: RawTensor, TFException
|
||||||
@ -53,33 +53,34 @@ graph(p::MaxPool, x) =
|
|||||||
graph(op::Op, xs...) = op.f(xs...)
|
graph(op::Op, xs...) = op.f(xs...)
|
||||||
|
|
||||||
function graph(ctx::Context, model, args...)
|
function graph(ctx::Context, model, args...)
|
||||||
node = graph(model, interpv(ctx, args)...)
|
node = graph(model, args...)
|
||||||
node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
|
node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
|
||||||
return node
|
return node
|
||||||
end
|
end
|
||||||
|
|
||||||
interp(ctx, c::Conv2D, x) =
|
interp(ctx, c::Conv2D, x) =
|
||||||
nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||||
|
|
||||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
param(ctx, p::Flux.Param{<:AArray}) =
|
||||||
haskey(ctx[:params], p.value) ?
|
haskey(ctx[:params], p) ?
|
||||||
ctx[:params][p.value] :
|
ctx[:params][p] :
|
||||||
(ctx[:params][p.value] =
|
(ctx[:params][p] =
|
||||||
ctx[:variables] ?
|
ctx[:variables] ?
|
||||||
Variable(Float32.(p.value.x)) :
|
Variable(Float32.(p.x)) :
|
||||||
placeholder(Float32))
|
placeholder(Float32))
|
||||||
|
|
||||||
interp(ctx, p::Constant) = p.value
|
param(ctx, x) = x
|
||||||
|
|
||||||
function interp(ctx, model, args...)
|
function interp(ctx, model, args...)
|
||||||
|
args = param.(ctx, args)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return graph(ctx, model, args...)
|
g == nothing && return graph(ctx, model, args...)
|
||||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||||
interpret(ctx, g, interpv(ctx, args)...)
|
interpret(ctx, g, args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...; variables = false)
|
function tograph(model, args...; variables = false)
|
||||||
ctx = Context(mux(iline, ilambda, interp),
|
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp),
|
||||||
params = ObjectIdDict(), stacks = Dict(), variables = variables)
|
params = ObjectIdDict(), stacks = Dict(), variables = variables)
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], ctx[:stacks], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import DataFlow: mapconst, cse
|
import DataFlow: cse
|
||||||
using MacroTools: @q
|
using MacroTools: @q
|
||||||
|
|
||||||
export @net
|
export @net
|
||||||
@ -6,13 +6,13 @@ export @net
|
|||||||
function graphdef(ex, params = [])
|
function graphdef(ex, params = [])
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
body = @> body MacroTools.flatten liftloops graphm DataFlow.il
|
body = @> body MacroTools.flatten liftloops graphm DataFlow.il
|
||||||
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||||
return args, body
|
return args, body
|
||||||
end
|
end
|
||||||
|
|
||||||
function makegraph(graph, args, params = [])
|
function makegraph(graph, args, params = [])
|
||||||
graph = prewalk(graph) do v
|
graph = prewalk(graph) do v
|
||||||
value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ?
|
isconstant(v) && (i = findfirst(args, value(v[1]))) ≠ 0 ?
|
||||||
inputnode(i) :
|
inputnode(i) :
|
||||||
v
|
v
|
||||||
end
|
end
|
||||||
@ -42,7 +42,7 @@ end
|
|||||||
|
|
||||||
function deref_params(v)
|
function deref_params(v)
|
||||||
map(v) do x
|
map(v) do x
|
||||||
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ end
|
|||||||
|
|
||||||
import Lazy: groupby
|
import Lazy: groupby
|
||||||
|
|
||||||
reifyparams(v::IVertex) = mapconst(x -> x isa Param ? x.x : x, v)
|
reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v)
|
||||||
|
|
||||||
# TODO: type hints for parameters
|
# TODO: type hints for parameters
|
||||||
|
|
||||||
@ -69,14 +69,14 @@ function process_type(ex)
|
|||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
|
||||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args, params))))
|
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
|
||||||
nothing
|
nothing
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function process_anon(ex)
|
function process_anon(ex)
|
||||||
args, body = graphdef(ex)
|
args, body = graphdef(ex)
|
||||||
:(Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args)[1])))))
|
:(Capacitor($(DataFlow.constructor(map(esc, makegraph(body, args)[1])))))
|
||||||
end
|
end
|
||||||
|
|
||||||
function process_def(ex)
|
function process_def(ex)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
function astuple(xs::Vertex)
|
function astuple(xs::Vertex)
|
||||||
isconstant(xs) && value(xs).value isa Tuple ? value(xs).value :
|
isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) :
|
||||||
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
|
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
|
||||||
nothing
|
nothing
|
||||||
end
|
end
|
||||||
@ -15,15 +15,13 @@ end
|
|||||||
|
|
||||||
function interp(ctx, f, xs...)
|
function interp(ctx, f, xs...)
|
||||||
g = graph(f)
|
g = graph(f)
|
||||||
|
g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph")
|
||||||
@icatch(ctx, g ≠ nothing ?
|
@icatch(ctx, g ≠ nothing ?
|
||||||
interpret(ctx, reifyparams(g), xs...) :
|
interpret(ctx, reifyparams(g), xs...) :
|
||||||
f(xs...))
|
f(xs...))
|
||||||
end
|
end
|
||||||
|
|
||||||
interp(ctx::Context, c::Constant{<:Param}) = c.value.x
|
|
||||||
interp(ctx::Context, c::Constant) = c.value
|
|
||||||
|
|
||||||
function interpmodel(m, args...)
|
function interpmodel(m, args...)
|
||||||
ctx = Context(mux(iline, ilambda, iargs, ituple, interp))
|
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
|
||||||
@ithrow interp(ctx, m, args...)
|
@ithrow interp(ctx, m, args...)
|
||||||
end
|
end
|
||||||
|
@ -8,6 +8,10 @@ end
|
|||||||
|
|
||||||
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
|
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
|
||||||
|
|
||||||
|
arghint(p::Param) = arghint(state(p))
|
||||||
|
arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||||
|
arghint(x) = constant(x)
|
||||||
|
|
||||||
function gethint(v::IVertex)
|
function gethint(v::IVertex)
|
||||||
while value(v) isa Union{Line,Frame} v = v[1] end
|
while value(v) isa Union{Line,Frame} v = v[1] end
|
||||||
value(v) isa Hint && return value(v).typ
|
value(v) isa Hint && return value(v).typ
|
||||||
@ -17,19 +21,16 @@ end
|
|||||||
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
||||||
ihint(f, args...) = f(args...)
|
ihint(f, args...) = f(args...)
|
||||||
|
|
||||||
hintify(ctx, c::Constant{<:Union{Param,AbstractArray}}) = hintify(ctx, state(c.value))
|
|
||||||
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
|
||||||
hintify(ctx, c::Constant) = vertex(c)
|
|
||||||
|
|
||||||
interpshape = mux(ilinev, ihint, iargs, hintify)
|
|
||||||
|
|
||||||
function hintify(ctx, f, xs...)
|
function hintify(ctx, f, xs...)
|
||||||
sh = infer(f, gethint.(xs)...)
|
xs = arghint.(xs)
|
||||||
|
sh = infer(f, map(gethint, xs)...)
|
||||||
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
||||||
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
||||||
vertex(f, xs...)
|
vertex(f, xs...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
|
||||||
|
|
||||||
function shapesv(f, args...)
|
function shapesv(f, args...)
|
||||||
(g = graph(f)) == nothing && return
|
(g = graph(f)) == nothing && return
|
||||||
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
||||||
|
Loading…
Reference in New Issue
Block a user