param collection

This commit is contained in:
Mike J Innes 2017-08-22 17:13:03 +01:00
parent 1179269355
commit 0ce8c0cee4
4 changed files with 26 additions and 0 deletions

View File

@ -16,6 +16,9 @@ include("Tracker/Tracker.jl")
using .Tracker
export track, back!
include("optimise/Optimise.jl")
using .Optimise
include("utils.jl")
include("compiler/Compiler.jl")

View File

@ -8,6 +8,8 @@ end
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
@forward Chain.layers Base.start, Base.next, Base.done
Optimise.children(c::Chain) = c.layers
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
Compiler.graph(s::Chain) =
@ -32,6 +34,8 @@ end
Linear(in::Integer, out::Integer, σ = identity; init = initn) =
Linear(σ, track(init(out, in)), track(init(out)))
Optimise.children(d::Linear) = (d.W, d.b)
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
function Base.show(io::IO, l::Linear)

5
src/optimise/Optimise.jl Normal file
View File

@ -0,0 +1,5 @@
module Optimise
include("params.jl")
end

14
src/optimise/params.jl Normal file
View File

@ -0,0 +1,14 @@
children(x) = ()
using ..Tracker.TrackedArray
using DataFlow: OSet
params(ps, p::TrackedArray) = push!(ps, p)
params(ps, m) = foreach(m -> params(ps, m), children(m))
function params(m)
ps = OSet()
params(ps, m)
return collect(ps)
end