Flow -> DataFlow
This commit is contained in:
parent
7cd94b4a5d
commit
53ebb5051a
|
@ -10,5 +10,5 @@ notifications:
|
||||||
# uncomment the following lines to override the default test script
|
# uncomment the following lines to override the default test script
|
||||||
script:
|
script:
|
||||||
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
||||||
- julia -e 'Pkg.clone("https://github.com/MikeInnes/Flow.jl")'
|
- julia -e 'Pkg.clone("https://github.com/MikeInnes/DataFlow.jl")'
|
||||||
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
module Flux
|
module Flux
|
||||||
|
|
||||||
using MacroTools, Lazy, Flow, Juno
|
using MacroTools, Lazy, DataFlow, Juno
|
||||||
import Flow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
|
import DataFlow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
|
||||||
Constant, constant, isconstant, value, inputs, thread!, value, inputs,
|
Constant, constant, isconstant, value, inputs, thread!, value, inputs,
|
||||||
Split, Group, group
|
Split, Group, group
|
||||||
import Juno: Tree, Row
|
import Juno: Tree, Row
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import Base: @get!
|
import Base: @get!
|
||||||
import Flow: Constant, postwalk, value, inputs, constant
|
import DataFlow: Constant, postwalk, value, inputs, constant
|
||||||
import TensorFlow: RawTensor
|
import TensorFlow: RawTensor
|
||||||
|
|
||||||
cvalue(x) = x
|
cvalue(x) = x
|
||||||
|
@ -26,7 +26,7 @@ graph(::Input, x) = x
|
||||||
graph(p::MaxPool, x) =
|
graph(p::MaxPool, x) =
|
||||||
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
||||||
|
|
||||||
graph(::Flow.Group, xs...) = (xs...,)
|
graph(::DataFlow.Group, xs...) = (xs...,)
|
||||||
|
|
||||||
graph(params::Associative, c::Conv2D, x) =
|
graph(params::Associative, c::Conv2D, x) =
|
||||||
nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
|
nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
|
||||||
|
@ -55,7 +55,7 @@ end
|
||||||
function graph(params::Associative, model, args...)
|
function graph(params::Associative, model, args...)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return graph(model, args...)
|
g == nothing && return graph(model, args...)
|
||||||
Flow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||||
graph(params, g, args...)
|
graph(params, g, args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ type SeqModel
|
||||||
state::Any
|
state::Any
|
||||||
end
|
end
|
||||||
|
|
||||||
cgroup(xs...) = Flow.group(map(constant, xs)...)
|
cgroup(xs...) = DataFlow.group(map(constant, xs)...)
|
||||||
|
|
||||||
function makesession(model::Flux.Unrolled)
|
function makesession(model::Flux.Unrolled)
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
module TF
|
module TF
|
||||||
|
|
||||||
using ..Flux, Flow, TensorFlow, Juno
|
using ..Flux, DataFlow, TensorFlow, Juno
|
||||||
import Flux: accuracy, spliceinputs, detuple
|
import Flux: accuracy, spliceinputs, detuple
|
||||||
|
|
||||||
export tf
|
export tf
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
# TODO: proper escaping
|
# TODO: proper escaping
|
||||||
|
|
||||||
import Flow: mapconst, cse
|
import DataFlow: mapconst, cse
|
||||||
|
|
||||||
export @net
|
export @net
|
||||||
|
|
||||||
function process_func(ex, params = [])
|
function process_func(ex, params = [])
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
body = @> body MacroTools.flatten liftloops!(params) graphm Flow.il
|
body = @> body MacroTools.flatten liftloops!(params) graphm DataFlow.il
|
||||||
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
||||||
return args, body
|
return args, body
|
||||||
end
|
end
|
||||||
|
@ -48,7 +48,7 @@ end
|
||||||
function build_backward(body, x, params = [])
|
function build_backward(body, x, params = [])
|
||||||
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
|
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
|
||||||
Δs = invert(body)
|
Δs = invert(body)
|
||||||
back = IVertex{Any}(Flow.Do())
|
back = IVertex{Any}(DataFlow.Do())
|
||||||
for param in params
|
for param in params
|
||||||
haskey(Δs, :(self.$param)) || continue
|
haskey(Δs, :(self.$param)) || continue
|
||||||
ex = Δs[:(self.$param)]
|
ex = Δs[:(self.$param)]
|
||||||
|
@ -76,7 +76,7 @@ function process_type(ex)
|
||||||
(self::$T)($(args...),) = $(build_forward(body, args))
|
(self::$T)($(args...),) = $(build_forward(body, args))
|
||||||
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
||||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||||
graph(self::$T) = $(Flow.constructor(makegraph(body, args)))
|
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
|
||||||
nothing
|
nothing
|
||||||
end |> esc
|
end |> esc
|
||||||
end
|
end
|
||||||
|
@ -84,7 +84,7 @@ end
|
||||||
function process_anon(ex)
|
function process_anon(ex)
|
||||||
args, body = process_func(ex)
|
args, body = process_func(ex)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
:(Flux.Capacitor($(Flow.constructor(makegraph(body, args))))) |> esc
|
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args))))) |> esc
|
||||||
end
|
end
|
||||||
|
|
||||||
macro net(ex)
|
macro net(ex)
|
||||||
|
|
|
@ -8,7 +8,7 @@ symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
|
||||||
function ∇v(v::Vertex, Δ)
|
function ∇v(v::Vertex, Δ)
|
||||||
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
|
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
|
||||||
Δ = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...)
|
Δ = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...)
|
||||||
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
map(i -> @flow(getindex($Δ, $i)), 1:DataFlow.nin(v))
|
||||||
end
|
end
|
||||||
|
|
||||||
function invert(v::IVertex, Δ = :Δ, out = d())
|
function invert(v::IVertex, Δ = :Δ, out = d())
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# TODO: most (all?) of this could be in Flow
|
# TODO: most (all?) of this could be in DataFlow
|
||||||
|
|
||||||
immutable ModelInput end
|
immutable ModelInput end
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ end
|
||||||
Delay(name) = Delay(name, nothing)
|
Delay(name) = Delay(name, nothing)
|
||||||
|
|
||||||
function liftloops!(ex, params)
|
function liftloops!(ex, params)
|
||||||
ex = Flow.normedges(ex)
|
ex = DataFlow.normedges(ex)
|
||||||
hidden = intersect((b.args[1] for b in ex.args), params)
|
hidden = intersect((b.args[1] for b in ex.args), params)
|
||||||
edges = Dict(h => gensym("edge") for h in hidden)
|
edges = Dict(h => gensym("edge") for h in hidden)
|
||||||
declared = Dict(h => false for h in hidden)
|
declared = Dict(h => false for h in hidden)
|
||||||
|
|
Loading…
Reference in New Issue