2017-08-19 19:20:20 +00:00
|
|
|
|
using DataFlow, MacroTools
|
2017-09-06 22:58:55 +00:00
|
|
|
|
using Flux: stack, unsqueeze
|
|
|
|
|
using Flux.Compiler: @net, graph
|
2017-08-19 19:20:20 +00:00
|
|
|
|
using DataFlow: Line, Frame
|
|
|
|
|
|
2017-08-19 19:26:27 +00:00
|
|
|
|
# Affine layer declared with `@net`: the anonymous forward function below is
# compiled by Flux.Compiler into a DataFlow graph.
# Fields:
#   W — weight matrix; the input is multiplied on the left (`x*W`)
#   b — bias, broadcast-added to the product (`.+ b`)
@net type Affine
  W
  b
  # Forward pass. The compiler tests pattern-match the graph produced by this
  # exact expression with `@capture`, so keep its form unchanged.
  x -> x*W .+ b
end
|
|
|
|
|
|
|
|
|
|
"""
    Affine(in::Integer, out::Integer; init = Flux.initn)

Construct an `Affine` layer from its dimensions. The weight is created with
`init(in, out)` and the bias with `init(1, out)`, in that order.
"""
function Affine(in::Integer, out::Integer; init = Flux.initn)
  weight = init(in, out)
  bias = init(1, out)
  return Affine(weight, bias)
end
|
|
|
|
|
|
2017-08-19 19:20:20 +00:00
|
|
|
|
# Two-layer model declared with `@net`: `first`, then a broadcast `σ`, then
# `second` followed by `softmax`.
# Fields:
#   first  — first layer, called as `first(x)`
#   second — second layer, called on the hidden activations
# NOTE(review): the field name `first` shadows `Base.first` inside the forward
# pass; it is kept because renaming would change the constructor interface.
@net type TLP
  first
  second
  function (x)
    # Hidden activations: `σ` broadcast over the first layer's output.
    l1 = σ.(first(x))
    # Output: softmax of the second layer. As the last expression, this
    # assignment's value is what the forward pass returns.
    l2 = softmax(second(l1))
  end
end
|
|
|
|
|
|
2017-08-19 19:52:17 +00:00
|
|
|
|
# Simple recurrent layer declared with `@net`.
# Fields:
#   Wxy — weights applied to the input (`x * Wxy`)
#   Wyy — weights applied to the previous output (`y{-1} * Wyy`)
#   by  — bias, broadcast-added
#   y   — the recurrent output/state. The unrolling test passes `r.y` as the
#         initial state, i.e. it supplies `y{-1}` on the first step
#         (NOTE(review): inferred from that test — confirm for other entry points).
@net type Recurrent
  Wxy; Wyy; by
  y
  function (x)
    # `y{-1}` is the `@net` syntax for `y` delayed by one step.
    y = tanh.( x * Wxy .+ y{-1} * Wyy .+ by )
  end
end
|
|
|
|
|
|
|
|
|
|
"""
    Recurrent(in, out; init = Flux.initn)

Construct a `Recurrent` layer with `in` input features and `out` units.

`init` is always called with the dimensions as separate arguments, the same
convention as the `Affine` constructor:

- `Wxy = init(in, out)`
- `Wyy = init(out, out)`
- `by  = init(1, out)`
- initial `y = init(1, out)`

Previously the two weight matrices were created with a tuple argument
(`init((in, out))`). That broke any `init` accepting only separate dimension
arguments, even though the same call already required that form for `by`
and `y`.
"""
Recurrent(in, out; init = Flux.initn) =
  Recurrent(init(in, out), init(out, out), init(1, out), init(1, out))
|
|
|
|
|
|
2017-08-19 19:20:20 +00:00
|
|
|
|
"""
    syntax(v::Vertex)
    syntax(model)

Return the `prettify`-ed DataFlow syntax of a graph vertex. For any other
argument, convert it with `graph` first and then render that graph.
"""
function syntax(v::Vertex)
  raw = DataFlow.syntax(v)
  return prettify(raw)
end

syntax(model) = syntax(graph(model))
|
|
|
|
|
|
|
|
|
|
@testset "Compiler" begin

  xs = randn(1, 10)
  d = Affine(10, 20)

  # The compiled layer must agree with the plain matrix expression.
  @test d(xs) ≈ (xs*d.W + d.b)

  # Smoke test: `@net` must also accept an anonymous function. The result is
  # not inspected further here.
  d1 = @net x -> x * d.W + d.b

  # The graph of an `Affine` layer should render as `x*W .+ b`, with the
  # input vertex and concrete parameter arrays in place.
  let
    matched = @capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
    # Check `matched` first, so a failed capture is reported as a plain test
    # failure instead of relying on whatever the pattern variables hold.
    @test matched && isa(x, DataFlow.Input) && W isa Array && b isa Array
  end

  # Calling the model directly and running it through the DataFlow
  # interpreter must both match composing the layers by hand.
  let a1 = Affine(10, 20), a2 = Affine(20, 15)
    tlp = TLP(a1, a2)
    expected = softmax(a2(σ.(a1(xs))))
    @test tlp(xs) ≈ expected
    @test Flux.Compiler.interpmodel(tlp, xs) ≈ expected
  end

  # Mismatched sizes (21 outputs fed into a layer expecting 20 inputs) must
  # make the interpreter throw. The innermost trace entry is the TLP model,
  # and the next one is the Affine layer that failed.
  let tlp = TLP(Affine(10, 21), Affine(20, 15))
    err = try
      Flux.Compiler.interpmodel(tlp, rand(1, 10))
      nothing  # no exception: recorded as a failure below
    catch ex
      ex
    end
    @test err !== nothing
    if err !== nothing
      @test err.trace[end].func == :TLP
      @test err.trace[end-1].func == :Affine
    end
  end

  # Thread `state` through `model(state, x) -> (state, y)` over each element
  # of `xs`. Return the final state and the collected outputs.
  function apply(model, xs, state)
    ys = similar(xs, 0)
    for x in xs
      state, y = model(state, x)
      push!(ys, y)
    end
    return state, ys
  end

  @testset "RNN unrolling" begin
    r = Recurrent(10, 5)
    xs = [rand(1, 10) for _ = 1:3]
    # Step the single-step unrolled model, seeded with the layer's stored `y`.
    _, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
    @test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
    ru = Flux.Compiler.unroll(r, 3)
    # NOTE(review): this comparison is not wrapped in `@test`, so its result
    # is discarded; the line only checks that the 3-step unrolled model runs
    # without error. Confirm the output layout of `ru(...)` before turning
    # it into an assertion.
    ru(unsqueeze(stack(squeeze.(xs, 1), 1), 1))[1] == squeeze.(ys, 1)
  end

end
|