model storage notes
This commit is contained in:
parent
1dcdf66651
commit
5f1f2ebaa2
|
@ -14,7 +14,8 @@ makedocs(modules=[Flux],
|
|||
"Debugging" => "models/debugging.md"],
|
||||
"Other APIs" => [
|
||||
"Batching" => "apis/batching.md",
|
||||
"Backends" => "apis/backends.md"],
|
||||
"Backends" => "apis/backends.md",
|
||||
"Storing Models" => "apis/storage.md"],
|
||||
"In Action" => [
|
||||
"Logistic Regression" => "examples/logreg.md",
|
||||
"Char RNN" => "examples/char-rnn.md"],
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Loading and Save Models
|
||||
|
||||
```julia
|
||||
model = Chain(Affine(10, 20), σ, Affine(20, 15), softmax)
|
||||
```
|
||||
|
||||
Since models are just simple Julia data structures, it's very easy to save and load them using any of Julia's existing serialisation formats. For example, using Julia's built-in `serialize`:
|
||||
|
||||
```julia
|
||||
open(io -> serialize(io, model), "model.jls", "w")
|
||||
open(io -> deserialize(io), "model.jls")
|
||||
```
|
||||
|
||||
One issue with `serialize` is that it doesn't promise compatibility between major Julia versions. For longer-term storage it's good to use a package like [JLD](https://github.com/JuliaIO/JLD.jl).
|
||||
|
||||
```julia
|
||||
using JLD
|
||||
@save "model.jld" model
|
||||
@load "model.jld"
|
||||
```
|
||||
|
||||
However, JLD will break for some models as functions are not supported on 0.5+. You can resolve that by checking out [this branch](https://github.com/JuliaIO/JLD.jl/pull/137).
|
||||
|
||||
Right now this is the only storage format Flux supports. In future Flux will support loading and saving other model formats (on an as-needed basis).
|
|
@ -20,7 +20,7 @@ type Chain <: Model
|
|||
end
|
||||
end
|
||||
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof
|
||||
|
||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||
back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
|
||||
|
@ -30,3 +30,5 @@ graph(s::Chain) =
|
|||
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
||||
|
||||
shape(c::Chain, in) = c.shape
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
|
||||
|
|
Loading…
Reference in New Issue