better alternative to basemodel
This commit is contained in:
parent
5f1f2ebaa2
commit
4d4979b401
@ -16,15 +16,13 @@ N = length(alphabet)
|
|||||||
|
|
||||||
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||||
|
|
||||||
basemodel = Chain(
|
model = Chain(
|
||||||
Input(N),
|
Input(N),
|
||||||
LSTM(N, 256),
|
LSTM(N, 256),
|
||||||
LSTM(256, 256),
|
LSTM(256, 256),
|
||||||
Affine(256, N),
|
Affine(256, N),
|
||||||
softmax)
|
softmax)
|
||||||
|
|
||||||
model = Chain(basemodel, softmax)
|
|
||||||
|
|
||||||
m = tf(unroll(model, nunroll))
|
m = tf(unroll(model, nunroll))
|
||||||
|
|
||||||
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||||
@ -38,5 +36,5 @@ function sample(model, n, temp = 1)
|
|||||||
return string(s...)
|
return string(s...)
|
||||||
end
|
end
|
||||||
|
|
||||||
sample(basemodel, 100)
|
sample(model[1:end-1], 100)
|
||||||
```
|
```
|
||||||
|
@ -13,15 +13,13 @@ N = length(alphabet)
|
|||||||
|
|
||||||
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||||
|
|
||||||
basemodel = Chain(
|
model = Chain(
|
||||||
Input(N),
|
Input(N),
|
||||||
LSTM(N, 256),
|
LSTM(N, 256),
|
||||||
LSTM(256, 256),
|
LSTM(256, 256),
|
||||||
Affine(256, N),
|
Affine(256, N),
|
||||||
softmax)
|
softmax)
|
||||||
|
|
||||||
model = Chain(basemodel, softmax)
|
|
||||||
|
|
||||||
m = tf(unroll(model, nunroll))
|
m = tf(unroll(model, nunroll))
|
||||||
|
|
||||||
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||||
@ -35,4 +33,4 @@ function sample(model, n, temp = 1)
|
|||||||
return string(s...)
|
return string(s...)
|
||||||
end
|
end
|
||||||
|
|
||||||
sample(basemodel, 100)
|
sample(model[1:end-1], 100)
|
||||||
|
@ -31,4 +31,4 @@ graph(s::Chain) =
|
|||||||
|
|
||||||
shape(c::Chain, in) = c.shape
|
shape(c::Chain, in) = c.shape
|
||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
Loading…
Reference in New Issue
Block a user