rm ambiguous function
This commit is contained in:
parent
121af0579a
commit
f7f8124a78
@ -1,5 +1,3 @@
|
|||||||
using Flux: stateless
|
|
||||||
|
|
||||||
function apply(model, xs, state)
|
function apply(model, xs, state)
|
||||||
ys = similar(xs, 0)
|
ys = similar(xs, 0)
|
||||||
for x in xs
|
for x in xs
|
||||||
@ -12,7 +10,7 @@ end
|
|||||||
@testset "RNN unrolling" begin
|
@testset "RNN unrolling" begin
|
||||||
r = Recurrent(10, 5)
|
r = Recurrent(10, 5)
|
||||||
xs = [rand(1, 10) for _ = 1:3]
|
xs = [rand(1, 10) for _ = 1:3]
|
||||||
_, ys = apply(stateless(unroll1(r)), xs, (r.y.x,))
|
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
|
||||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
||||||
ru = unroll(r, 3)
|
ru = unroll(r, 3)
|
||||||
ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
|
ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
|
||||||
|
Loading…
Reference in New Issue
Block a user