fix tests / deprecations

This commit is contained in:
Mike J Innes 2017-08-22 18:04:10 +01:00
parent eb584b066e
commit 5d6e8e2777

View File

@ -16,7 +16,7 @@ Affine(in::Integer, out::Integer; init = Flux.initn) =
first first
second second
function (x) function (x)
l1 = σ(first(x)) l1 = σ.(first(x))
l2 = softmax(second(l1)) l2 = softmax(second(l1))
end end
end end
@ -25,7 +25,7 @@ end
Wxy; Wyy; by Wxy; Wyy; by
y y
function (x) function (x)
y = tanh( x * Wxy .+ y{-1} * Wyy .+ by ) y = tanh.( x * Wxy .+ y{-1} * Wyy .+ by )
end end
end end
@ -51,8 +51,8 @@ end
let a1 = Affine(10, 20), a2 = Affine(20, 15) let a1 = Affine(10, 20), a2 = Affine(20, 15)
tlp = TLP(a1, a2) tlp = TLP(a1, a2)
@test tlp(xs) softmax(a2(σ(a1(xs)))) @test tlp(xs) softmax(a2(σ.(a1(xs))))
@test Flux.Compiler.interpmodel(tlp, xs) softmax(a2(σ(a1(xs)))) @test Flux.Compiler.interpmodel(tlp, xs) softmax(a2(σ.(a1(xs))))
end end
let tlp = TLP(Affine(10, 21), Affine(20, 15)) let tlp = TLP(Affine(10, 21), Affine(20, 15))
@ -78,7 +78,7 @@ end
r = Recurrent(10, 5) r = Recurrent(10, 5)
xs = [rand(1, 10) for _ = 1:3] xs = [rand(1, 10) for _ = 1:3]
_, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,)) _, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
@test ys[1] == tanh(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by) @test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
ru = Flux.Compiler.unroll(r, 3) ru = Flux.Compiler.unroll(r, 3)
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys) ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
end end