From a5818569547ea2b701839d2db6c34f6959c8424d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sat, 19 Aug 2017 20:38:20 +0100 Subject: [PATCH] remove params from compiler --- src/compiler/code.jl | 23 +++-------------------- src/compiler/interp.jl | 2 +- src/compiler/loops.jl | 2 +- test/compiler.jl | 8 ++++---- 4 files changed, 9 insertions(+), 26 deletions(-) diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 027f862c..b873547a 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -1,6 +1,5 @@ import DataFlow: cse using MacroTools: @q, @> -import ..Flux: Param, param, state graph(m) = nothing @@ -28,34 +27,18 @@ end function build_type(T, params) @esc T - ex = quote - type $T + :(type $T $(params...) - end - end - if any(x->isexpr(x, Symbol), params) - push!(ex.args, - :($T($(map(x->isexpr(x, Symbol) ? :($x::AbstractArray) : x, params)...)) = - $T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...)))) - end - ex -end - -function deref_params(v) - map(v) do x - @capture(x, self.p_) ? :(Flux.state(self.$p)) : x - end + end) end function build_forward(body, args) iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph")) - applylines(syntax(cse(deref_params(body)))) + applylines(syntax(cse(body))) end import Lazy: groupby -reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v) - # TODO: type hints for parameters function process_type(ex) diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index 85618127..d9759260 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -17,7 +17,7 @@ function interp(ctx, f, xs...) g = graph(f) g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph") @icatch(ctx, g ≠ nothing ? - interpret(ctx, reifyparams(g), xs...) : + interpret(ctx, g, xs...) : f(xs...)) end diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index dc228526..6ccb9d39 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -9,7 +9,7 @@ mutable struct Stateful ostate::Vector{Any} end -Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss)) +Stateful(model, ss) = Stateful(model, ss, ss, ss) function Base.show(io::IO, m::Stateful) print(io, "Stateful(") diff --git a/test/compiler.jl b/test/compiler.jl index 7af6bad5..b26fd288 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -29,13 +29,13 @@ syntax(x) = syntax(graph(x)) xs = randn(1, 10) d = Affine(10, 20) -@test d(xs) ≈ (xs*d.W.x + d.b.x) +@test d(xs) ≈ (xs*d.W + d.b) d1 = @net x -> x * d.W + d.b let @capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_)))) - @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param) + @test isa(x, DataFlow.Input) && W isa Array && b isa Array end let a1 = Affine(10, 20), a2 = Affine(20, 15) @@ -66,8 +66,8 @@ end @testset "RNN unrolling" begin r = Recurrent(10, 5) xs = [rand(1, 10) for _ = 1:3] - _, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,)) - @test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x) + _, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,)) + @test ys[1] == tanh(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by) ru = Flux.Compiler.unroll(r, 3) ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys) end