Flow -> DataFlow

This commit is contained in:
Mike J Innes 2016-10-31 12:38:18 +00:00
parent 7cd94b4a5d
commit 53ebb5051a
10 changed files with 17 additions and 17 deletions

View File

@ -10,5 +10,5 @@ notifications:
# uncomment the following lines to override the default test script
script:
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
- julia -e 'Pkg.clone("https://github.com/MikeInnes/Flow.jl")'
- julia -e 'Pkg.clone("https://github.com/MikeInnes/DataFlow.jl")'
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'

View File

@ -1,2 +1,2 @@
require:
MikeInnes/Flow
MikeInnes/DataFlow

View File

@ -1,7 +1,7 @@
module Flux
using MacroTools, Lazy, Flow, Juno
import Flow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
using MacroTools, Lazy, DataFlow, Juno
import DataFlow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
Constant, constant, isconstant, value, inputs, thread!, value, inputs,
Split, Group, group
import Juno: Tree, Row

View File

@ -1,5 +1,5 @@
import Base: @get!
import Flow: Constant, postwalk, value, inputs, constant
import DataFlow: Constant, postwalk, value, inputs, constant
import TensorFlow: RawTensor
cvalue(x) = x
@ -26,7 +26,7 @@ graph(::Input, x) = x
graph(p::MaxPool, x) =
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
graph(::Flow.Group, xs...) = (xs...,)
graph(::DataFlow.Group, xs...) = (xs...,)
graph(params::Associative, c::Conv2D, x) =
nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
@ -55,7 +55,7 @@ end
function graph(params::Associative, model, args...)
g = Flux.graph(model)
g == nothing && return graph(model, args...)
Flow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
graph(params, g, args...)
end

View File

@ -5,7 +5,7 @@ type SeqModel
state::Any
end
cgroup(xs...) = Flow.group(map(constant, xs)...)
cgroup(xs...) = DataFlow.group(map(constant, xs)...)
function makesession(model::Flux.Unrolled)
sess = Session(Graph())

View File

@ -1,6 +1,6 @@
module TF
using ..Flux, Flow, TensorFlow, Juno
using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy, spliceinputs, detuple
export tf

View File

@ -1,12 +1,12 @@
# TODO: proper escaping
import Flow: mapconst, cse
import DataFlow: mapconst, cse
export @net
function process_func(ex, params = [])
@capture(shortdef(ex), (args__,) -> body_)
body = @> body MacroTools.flatten liftloops!(params) graphm Flow.il
body = @> body MacroTools.flatten liftloops!(params) graphm DataFlow.il
body = mapconst(x -> x in params ? :(self.$x) : x, body)
return args, body
end
@ -48,7 +48,7 @@ end
function build_backward(body, x, params = [])
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
Δs = invert(body)
back = IVertex{Any}(Flow.Do())
back = IVertex{Any}(DataFlow.Do())
for param in params
haskey(Δs, :(self.$param)) || continue
ex = Δs[:(self.$param)]
@ -76,7 +76,7 @@ function process_type(ex)
(self::$T)($(args...),) = $(build_forward(body, args))
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
graph(self::$T) = $(Flow.constructor(makegraph(body, args)))
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
nothing
end |> esc
end
@ -84,7 +84,7 @@ end
function process_anon(ex)
args, body = process_func(ex)
@assert length(args) == 1
:(Flux.Capacitor($(Flow.constructor(makegraph(body, args))))) |> esc
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args))))) |> esc
end
macro net(ex)

View File

@ -8,7 +8,7 @@ symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
function ∇v(v::Vertex, Δ)
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
Δ = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...)
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
map(i -> @flow(getindex($Δ, $i)), 1:DataFlow.nin(v))
end
function invert(v::IVertex, Δ = , out = d())

View File

@ -1,4 +1,4 @@
# TODO: most (all?) of this could be in Flow
# TODO: most (all?) of this could be in DataFlow
immutable ModelInput end

View File

@ -8,7 +8,7 @@ end
Delay(name) = Delay(name, nothing)
function liftloops!(ex, params)
ex = Flow.normedges(ex)
ex = DataFlow.normedges(ex)
hidden = intersect((b.args[1] for b in ex.args), params)
edges = Dict(h => gensym("edge") for h in hidden)
declared = Dict(h => false for h in hidden)