model storage notes

This commit is contained in:
Mike J Innes 2017-02-28 16:41:33 +00:00
parent 1dcdf66651
commit 5f1f2ebaa2
3 changed files with 29 additions and 2 deletions

View File

@ -14,7 +14,8 @@ makedocs(modules=[Flux],
"Debugging" => "models/debugging.md"],
"Other APIs" => [
"Batching" => "apis/batching.md",
"Backends" => "apis/backends.md"],
"Backends" => "apis/backends.md",
"Storing Models" => "apis/storage.md"],
"In Action" => [
"Logistic Regression" => "examples/logreg.md",
"Char RNN" => "examples/char-rnn.md"],

24
docs/src/apis/storage.md Normal file
View File

@ -0,0 +1,24 @@
# Loading and Save Models
```julia
model = Chain(Affine(10, 20), σ, Affine(20, 15), softmax)
```
Since models are just simple Julia data structures, it's very easy to save and load them using any of Julia's existing serialisation formats. For example, using Julia's built-in `serialize`:
```julia
open(io -> serialize(io, model), "model.jls", "w")
open(io -> deserialize(io), "model.jls")
```
One issue with `serialize` is that it doesn't promise compatibility between major Julia versions. For longer-term storage it's good to use a package like [JLD](https://github.com/JuliaIO/JLD.jl).
```julia
using JLD
@save "model.jld" model
@load "model.jld"
```
However, JLD will break for some models as functions are not supported on 0.5+. You can resolve that by checking out [this branch](https://github.com/JuliaIO/JLD.jl/pull/137).
Right now this is the only storage format Flux supports. In future Flux will support loading and saving other model formats (on an as-needed basis).

View File

@ -20,7 +20,7 @@ type Chain <: Model
end
end
@forward Chain.layers Base.getindex, Base.first, Base.last
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
@ -30,3 +30,5 @@ graph(s::Chain) =
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
shape(c::Chain, in) = c.shape
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])