Flow -> DataFlow
This commit is contained in:
parent
7cd94b4a5d
commit
53ebb5051a
|
@ -10,5 +10,5 @@ notifications:
|
|||
# uncomment the following lines to override the default test script
|
||||
script:
|
||||
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
||||
- julia -e 'Pkg.clone("https://github.com/MikeInnes/Flow.jl")'
|
||||
- julia -e 'Pkg.clone("https://github.com/MikeInnes/DataFlow.jl")'
|
||||
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
module Flux
|
||||
|
||||
using MacroTools, Lazy, Flow, Juno
|
||||
import Flow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
|
||||
using MacroTools, Lazy, DataFlow, Juno
|
||||
import DataFlow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
|
||||
Constant, constant, isconstant, value, inputs, thread!, value, inputs,
|
||||
Split, Group, group
|
||||
import Juno: Tree, Row
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import Base: @get!
|
||||
import Flow: Constant, postwalk, value, inputs, constant
|
||||
import DataFlow: Constant, postwalk, value, inputs, constant
|
||||
import TensorFlow: RawTensor
|
||||
|
||||
cvalue(x) = x
|
||||
|
@ -26,7 +26,7 @@ graph(::Input, x) = x
|
|||
graph(p::MaxPool, x) =
|
||||
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
||||
|
||||
graph(::Flow.Group, xs...) = (xs...,)
|
||||
graph(::DataFlow.Group, xs...) = (xs...,)
|
||||
|
||||
graph(params::Associative, c::Conv2D, x) =
|
||||
nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
|
||||
|
@ -55,7 +55,7 @@ end
|
|||
function graph(params::Associative, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(model, args...)
|
||||
Flow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
graph(params, g, args...)
|
||||
end
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ type SeqModel
|
|||
state::Any
|
||||
end
|
||||
|
||||
cgroup(xs...) = Flow.group(map(constant, xs)...)
|
||||
cgroup(xs...) = DataFlow.group(map(constant, xs)...)
|
||||
|
||||
function makesession(model::Flux.Unrolled)
|
||||
sess = Session(Graph())
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
module TF
|
||||
|
||||
using ..Flux, Flow, TensorFlow, Juno
|
||||
using ..Flux, DataFlow, TensorFlow, Juno
|
||||
import Flux: accuracy, spliceinputs, detuple
|
||||
|
||||
export tf
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# TODO: proper escaping
|
||||
|
||||
import Flow: mapconst, cse
|
||||
import DataFlow: mapconst, cse
|
||||
|
||||
export @net
|
||||
|
||||
function process_func(ex, params = [])
|
||||
@capture(shortdef(ex), (args__,) -> body_)
|
||||
body = @> body MacroTools.flatten liftloops!(params) graphm Flow.il
|
||||
body = @> body MacroTools.flatten liftloops!(params) graphm DataFlow.il
|
||||
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
||||
return args, body
|
||||
end
|
||||
|
@ -48,7 +48,7 @@ end
|
|||
function build_backward(body, x, params = [])
|
||||
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
|
||||
Δs = invert(body)
|
||||
back = IVertex{Any}(Flow.Do())
|
||||
back = IVertex{Any}(DataFlow.Do())
|
||||
for param in params
|
||||
haskey(Δs, :(self.$param)) || continue
|
||||
ex = Δs[:(self.$param)]
|
||||
|
@ -76,7 +76,7 @@ function process_type(ex)
|
|||
(self::$T)($(args...),) = $(build_forward(body, args))
|
||||
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||
graph(self::$T) = $(Flow.constructor(makegraph(body, args)))
|
||||
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
|
||||
nothing
|
||||
end |> esc
|
||||
end
|
||||
|
@ -84,7 +84,7 @@ end
|
|||
function process_anon(ex)
|
||||
args, body = process_func(ex)
|
||||
@assert length(args) == 1
|
||||
:(Flux.Capacitor($(Flow.constructor(makegraph(body, args))))) |> esc
|
||||
:(Flux.Capacitor($(DataFlow.constructor(makegraph(body, args))))) |> esc
|
||||
end
|
||||
|
||||
macro net(ex)
|
||||
|
|
|
@ -8,7 +8,7 @@ symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
|
|||
function ∇v(v::Vertex, Δ)
|
||||
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
|
||||
Δ = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...)
|
||||
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
||||
map(i -> @flow(getindex($Δ, $i)), 1:DataFlow.nin(v))
|
||||
end
|
||||
|
||||
function invert(v::IVertex, Δ = :Δ, out = d())
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# TODO: most (all?) of this could be in Flow
|
||||
# TODO: most (all?) of this could be in DataFlow
|
||||
|
||||
immutable ModelInput end
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ end
|
|||
Delay(name) = Delay(name, nothing)
|
||||
|
||||
function liftloops!(ex, params)
|
||||
ex = Flow.normedges(ex)
|
||||
ex = DataFlow.normedges(ex)
|
||||
hidden = intersect((b.args[1] for b in ex.args), params)
|
||||
edges = Dict(h => gensym("edge") for h in hidden)
|
||||
declared = Dict(h => false for h in hidden)
|
||||
|
|
Loading…
Reference in New Issue