2017-02-23 22:40:07 +00:00
|
|
|
|
# Two-layer perceptron used as a test fixture.
# `first` and `second` are sub-models supplied positionally at construction
# (e.g. `TLP(Affine(10, 20), Affine(20, 15))`). The forward pass applies a
# sigmoid after `first` and a softmax after `second`. The last assignment (`l2`)
# is the value the model returns.
@net type TLP
  first
  second
  function (x)
    l1 = σ(first(x))
    l2 = softmax(second(l1))
  end
end
|
2016-12-15 22:31:39 +00:00
|
|
|
|
|
2017-03-06 16:12:03 +00:00
|
|
|
|
# Two-output model: multiplies the same input by two independent weights,
# `W` and `V`, and returns both products as a tuple `(x*W, x*V)`.
@net type Multi
  W
  V
  x -> (x*W, x*V)
end
|
|
|
|
|
|
|
|
|
|
# Convenience constructor: builds a `Multi` whose two weight matrices are
# both `in × out` and filled with samples from `randn`.
function Multi(in::Integer, out::Integer)
  W = randn(in, out)  # drawn first, preserving the original RNG call order
  V = randn(in, out)
  return Multi(W, V)
end
|
|
|
|
|
|
2017-02-23 22:40:07 +00:00
|
|
|
|
@testset "Basics" begin

  xs = randn(10)
  d = Affine(10, 20)

  # Calling an Affine layer on a vector matches the explicit `x'W + b` product.
  @test d(xs) ≈ (xs'*d.W.x + d.b.x)[1,:]

  # The syntax graph of an Affine layer has the shape `x[1] * W + b`:
  # the input is a DataFlow.Input and both weights are Params.
  let
    @capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
    @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
  end

  # Composing two Affine layers in a TLP matches applying them by hand.
  # This is checked through both the native call and the interpreter. Shape
  # inference for a (1, 10) input is checked to give (1, 15).
  let a1 = Affine(10, 20), a2 = Affine(20, 15)
    tlp = TLP(a1, a2)
    @test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
    @test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
    @test Flux.infer(tlp, (1, 10)) == (1,15)
  end

  # The layers are deliberately mismatched: 21 outputs feed a layer built
  # with 20 inputs. Interpreting the model is expected to throw. The last
  # trace entry should name TLP, and the one before it Flux.Affine.
  let tlp = TLP(Affine(10, 21), Affine(20, 15))
    threw = false
    e = try
      Flux.interpmodel(tlp, rand(10))
    catch err
      threw = true
      err
    end
    # Assert the failure explicitly. An unexpected success is then reported
    # as a test failure, not as an opaque error when accessing `e.trace`.
    @test threw
    if threw
      @test e.trace[end].func == :TLP
      @test e.trace[end-1].func == Symbol("Flux.Affine")
    end
  end

  # Multi returns one product per weight matrix, each matching the
  # corresponding hand-computed `W'x` / `V'x`.
  let m = Multi(10, 15)
    x = rand(10)
    @test all(isapprox.(m(x), (m.W.x' * x, m.V.x' * x)))
  end

end
|