split basemodel/model
This commit is contained in:
parent
209c089e5a
commit
05d988dd74
|
@ -16,13 +16,15 @@ N = length(alphabet)
|
|||
|
||||
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||
|
||||
model = Chain(
|
||||
basemodel = Chain(
|
||||
Input(N),
|
||||
LSTM(N, 256),
|
||||
LSTM(256, 256),
|
||||
Affine(256, N),
|
||||
softmax)
|
||||
|
||||
model = Chain(basemodel, softmax)
|
||||
|
||||
m = tf(unroll(model, nunroll))
|
||||
|
||||
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||
|
@ -36,5 +38,5 @@ function sample(model, n, temp = 1)
|
|||
return string(s...)
|
||||
end
|
||||
|
||||
sample(model, 100)
|
||||
sample(basemodel, 100)
|
||||
```
|
||||
|
|
|
@ -13,13 +13,15 @@ N = length(alphabet)
|
|||
|
||||
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||
|
||||
model = Chain(
|
||||
basemodel = Chain(
|
||||
Input(N),
|
||||
LSTM(N, 256),
|
||||
LSTM(256, 256),
|
||||
Affine(256, N),
|
||||
softmax)
|
||||
|
||||
model = Chain(basemodel, softmax)
|
||||
|
||||
m = tf(unroll(model, nunroll))
|
||||
|
||||
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||
|
@ -33,4 +35,4 @@ function sample(model, n, temp = 1)
|
|||
return string(s...)
|
||||
end
|
||||
|
||||
sample(model, 100)
|
||||
sample(basemodel, 100)
|
||||
|
|
Loading…
Reference in New Issue