remove params from compiler

This commit is contained in:
Mike J Innes 2017-08-19 20:38:20 +01:00
parent c7a07562d0
commit a581856954
4 changed files with 9 additions and 26 deletions

View File

@ -1,6 +1,5 @@
import DataFlow: cse
using MacroTools: @q, @>
import ..Flux: Param, param, state
graph(m) = nothing
@ -28,34 +27,18 @@ end
function build_type(T, params)
@esc T
ex = quote
type $T
:(type $T
$(params...)
end
end
if any(x->isexpr(x, Symbol), params)
push!(ex.args,
:($T($(map(x->isexpr(x, Symbol) ? :($x::AbstractArray) : x, params)...)) =
$T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...))))
end
ex
end
function deref_params(v)
map(v) do x
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
end
end)
end
function build_forward(body, args)
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
applylines(syntax(cse(deref_params(body))))
applylines(syntax(cse(body)))
end
import Lazy: groupby
reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v)
# TODO: type hints for parameters
function process_type(ex)

View File

@ -17,7 +17,7 @@ function interp(ctx, f, xs...)
g = graph(f)
g nothing && iscyclic(g) && error("Can't interpret cyclic graph")
@icatch(ctx, g nothing ?
interpret(ctx, reifyparams(g), xs...) :
interpret(ctx, g, xs...) :
f(xs...))
end

View File

@ -9,7 +9,7 @@ mutable struct Stateful
ostate::Vector{Any}
end
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
Stateful(model, ss) = Stateful(model, ss, ss, ss)
function Base.show(io::IO, m::Stateful)
print(io, "Stateful(")

View File

@ -29,13 +29,13 @@ syntax(x) = syntax(graph(x))
xs = randn(1, 10)
d = Affine(10, 20)
@test d(xs) (xs*d.W.x + d.b.x)
@test d(xs) (xs*d.W + d.b)
d1 = @net x -> x * d.W + d.b
let
@capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
@test isa(x, DataFlow.Input) && W isa Array && b isa Array
end
let a1 = Affine(10, 20), a2 = Affine(20, 15)
@ -66,8 +66,8 @@ end
@testset "RNN unrolling" begin
r = Recurrent(10, 5)
xs = [rand(1, 10) for _ = 1:3]
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
@test ys[1] == tanh(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
ru = Flux.Compiler.unroll(r, 3)
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
end