From 0ce8c0cee42c3aabc2e816f26a21e36e8b050e19 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 22 Aug 2017 17:13:03 +0100 Subject: [PATCH] param collection --- src/Flux.jl | 3 +++ src/layers/basic.jl | 4 ++++ src/optimise/Optimise.jl | 5 +++++ src/optimise/params.jl | 14 ++++++++++++++ 4 files changed, 26 insertions(+) create mode 100644 src/optimise/Optimise.jl create mode 100644 src/optimise/params.jl diff --git a/src/Flux.jl b/src/Flux.jl index fbc4d4af..d65aac32 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -16,6 +16,9 @@ include("Tracker/Tracker.jl") using .Tracker export track, back! +include("optimise/Optimise.jl") +using .Optimise + include("utils.jl") include("compiler/Compiler.jl") diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a6a8bd62..911a2ce4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -8,6 +8,8 @@ end @forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! @forward Chain.layers Base.start, Base.next, Base.done +Optimise.children(c::Chain) = c.layers + (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) Compiler.graph(s::Chain) = @@ -32,6 +34,8 @@ end Linear(in::Integer, out::Integer, σ = identity; init = initn) = Linear(σ, track(init(out, in)), track(init(out))) +Optimise.children(d::Linear) = (d.W, d.b) + (a::Linear)(x) = a.σ.(a.W*x .+ a.b) function Base.show(io::IO, l::Linear) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl new file mode 100644 index 00000000..cc64bb31 --- /dev/null +++ b/src/optimise/Optimise.jl @@ -0,0 +1,5 @@ +module Optimise + +include("params.jl") + +end diff --git a/src/optimise/params.jl b/src/optimise/params.jl new file mode 100644 index 00000000..e3fff208 --- /dev/null +++ b/src/optimise/params.jl @@ -0,0 +1,14 @@ +children(x) = () + +using ..Tracker.TrackedArray +using DataFlow: OSet + +params(ps, p::TrackedArray) = push!(ps, p) + +params(ps, m) = foreach(m -> params(ps, m), children(m)) + +function params(m) + ps = OSet() + params(ps, m) + return collect(ps) +end