remove params from compiler
This commit is contained in:
parent
c7a07562d0
commit
a581856954
@ -1,6 +1,5 @@
|
||||
import DataFlow: cse
|
||||
using MacroTools: @q, @>
|
||||
import ..Flux: Param, param, state
|
||||
|
||||
graph(m) = nothing
|
||||
|
||||
@ -28,34 +27,18 @@ end
|
||||
|
||||
function build_type(T, params)
|
||||
@esc T
|
||||
ex = quote
|
||||
type $T
|
||||
:(type $T
|
||||
$(params...)
|
||||
end
|
||||
end
|
||||
if any(x->isexpr(x, Symbol), params)
|
||||
push!(ex.args,
|
||||
:($T($(map(x->isexpr(x, Symbol) ? :($x::AbstractArray) : x, params)...)) =
|
||||
$T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...))))
|
||||
end
|
||||
ex
|
||||
end
|
||||
|
||||
function deref_params(v)
|
||||
map(v) do x
|
||||
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
function build_forward(body, args)
|
||||
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
||||
applylines(syntax(cse(deref_params(body))))
|
||||
applylines(syntax(cse(body)))
|
||||
end
|
||||
|
||||
import Lazy: groupby
|
||||
|
||||
reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v)
|
||||
|
||||
# TODO: type hints for parameters
|
||||
|
||||
function process_type(ex)
|
||||
|
@ -17,7 +17,7 @@ function interp(ctx, f, xs...)
|
||||
g = graph(f)
|
||||
g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph")
|
||||
@icatch(ctx, g ≠ nothing ?
|
||||
interpret(ctx, reifyparams(g), xs...) :
|
||||
interpret(ctx, g, xs...) :
|
||||
f(xs...))
|
||||
end
|
||||
|
||||
|
@ -9,7 +9,7 @@ mutable struct Stateful
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
||||
Stateful(model, ss) = Stateful(model, ss, ss, ss)
|
||||
|
||||
function Base.show(io::IO, m::Stateful)
|
||||
print(io, "Stateful(")
|
||||
|
@ -29,13 +29,13 @@ syntax(x) = syntax(graph(x))
|
||||
xs = randn(1, 10)
|
||||
d = Affine(10, 20)
|
||||
|
||||
@test d(xs) ≈ (xs*d.W.x + d.b.x)
|
||||
@test d(xs) ≈ (xs*d.W + d.b)
|
||||
|
||||
d1 = @net x -> x * d.W + d.b
|
||||
|
||||
let
|
||||
@capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
|
||||
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
||||
@test isa(x, DataFlow.Input) && W isa Array && b isa Array
|
||||
end
|
||||
|
||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||
@ -66,8 +66,8 @@ end
|
||||
@testset "RNN unrolling" begin
|
||||
r = Recurrent(10, 5)
|
||||
xs = [rand(1, 10) for _ = 1:3]
|
||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
|
||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
|
||||
@test ys[1] == tanh(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
|
||||
ru = Flux.Compiler.unroll(r, 3)
|
||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user