remove params from compiler
This commit is contained in:
parent
c7a07562d0
commit
a581856954
@ -1,6 +1,5 @@
|
|||||||
import DataFlow: cse
|
import DataFlow: cse
|
||||||
using MacroTools: @q, @>
|
using MacroTools: @q, @>
|
||||||
import ..Flux: Param, param, state
|
|
||||||
|
|
||||||
graph(m) = nothing
|
graph(m) = nothing
|
||||||
|
|
||||||
@ -28,34 +27,18 @@ end
|
|||||||
|
|
||||||
function build_type(T, params)
|
function build_type(T, params)
|
||||||
@esc T
|
@esc T
|
||||||
ex = quote
|
:(type $T
|
||||||
type $T
|
|
||||||
$(params...)
|
$(params...)
|
||||||
end
|
end)
|
||||||
end
|
|
||||||
if any(x->isexpr(x, Symbol), params)
|
|
||||||
push!(ex.args,
|
|
||||||
:($T($(map(x->isexpr(x, Symbol) ? :($x::AbstractArray) : x, params)...)) =
|
|
||||||
$T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...))))
|
|
||||||
end
|
|
||||||
ex
|
|
||||||
end
|
|
||||||
|
|
||||||
function deref_params(v)
|
|
||||||
map(v) do x
|
|
||||||
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_forward(body, args)
|
function build_forward(body, args)
|
||||||
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
||||||
applylines(syntax(cse(deref_params(body))))
|
applylines(syntax(cse(body)))
|
||||||
end
|
end
|
||||||
|
|
||||||
import Lazy: groupby
|
import Lazy: groupby
|
||||||
|
|
||||||
reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v)
|
|
||||||
|
|
||||||
# TODO: type hints for parameters
|
# TODO: type hints for parameters
|
||||||
|
|
||||||
function process_type(ex)
|
function process_type(ex)
|
||||||
|
@ -17,7 +17,7 @@ function interp(ctx, f, xs...)
|
|||||||
g = graph(f)
|
g = graph(f)
|
||||||
g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph")
|
g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph")
|
||||||
@icatch(ctx, g ≠ nothing ?
|
@icatch(ctx, g ≠ nothing ?
|
||||||
interpret(ctx, reifyparams(g), xs...) :
|
interpret(ctx, g, xs...) :
|
||||||
f(xs...))
|
f(xs...))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ mutable struct Stateful
|
|||||||
ostate::Vector{Any}
|
ostate::Vector{Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
Stateful(model, ss) = Stateful(model, ss, ss, ss)
|
||||||
|
|
||||||
function Base.show(io::IO, m::Stateful)
|
function Base.show(io::IO, m::Stateful)
|
||||||
print(io, "Stateful(")
|
print(io, "Stateful(")
|
||||||
|
@ -29,13 +29,13 @@ syntax(x) = syntax(graph(x))
|
|||||||
xs = randn(1, 10)
|
xs = randn(1, 10)
|
||||||
d = Affine(10, 20)
|
d = Affine(10, 20)
|
||||||
|
|
||||||
@test d(xs) ≈ (xs*d.W.x + d.b.x)
|
@test d(xs) ≈ (xs*d.W + d.b)
|
||||||
|
|
||||||
d1 = @net x -> x * d.W + d.b
|
d1 = @net x -> x * d.W + d.b
|
||||||
|
|
||||||
let
|
let
|
||||||
@capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
|
@capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
|
||||||
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
@test isa(x, DataFlow.Input) && W isa Array && b isa Array
|
||||||
end
|
end
|
||||||
|
|
||||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||||
@ -66,8 +66,8 @@ end
|
|||||||
@testset "RNN unrolling" begin
|
@testset "RNN unrolling" begin
|
||||||
r = Recurrent(10, 5)
|
r = Recurrent(10, 5)
|
||||||
xs = [rand(1, 10) for _ = 1:3]
|
xs = [rand(1, 10) for _ = 1:3]
|
||||||
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y.x,))
|
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
|
||||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
@test ys[1] == tanh(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
|
||||||
ru = Flux.Compiler.unroll(r, 3)
|
ru = Flux.Compiler.unroll(r, 3)
|
||||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user