param collection
This commit is contained in:
parent
1179269355
commit
0ce8c0cee4
@ -16,6 +16,9 @@ include("Tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
export track, back!
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
|
||||
include("utils.jl")
|
||||
|
||||
include("compiler/Compiler.jl")
|
||||
|
@ -8,6 +8,8 @@ end
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||
@forward Chain.layers Base.start, Base.next, Base.done
|
||||
|
||||
Optimise.children(c::Chain) = c.layers
|
||||
|
||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||
|
||||
Compiler.graph(s::Chain) =
|
||||
@ -32,6 +34,8 @@ end
|
||||
Linear(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||
Linear(σ, track(init(out, in)), track(init(out)))
|
||||
|
||||
Optimise.children(d::Linear) = (d.W, d.b)
|
||||
|
||||
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
|
||||
|
||||
function Base.show(io::IO, l::Linear)
|
||||
|
5
src/optimise/Optimise.jl
Normal file
5
src/optimise/Optimise.jl
Normal file
@ -0,0 +1,5 @@
|
||||
module Optimise
|
||||
|
||||
include("params.jl")
|
||||
|
||||
end
|
14
src/optimise/params.jl
Normal file
14
src/optimise/params.jl
Normal file
@ -0,0 +1,14 @@
|
||||
children(x) = ()
|
||||
|
||||
using ..Tracker.TrackedArray
|
||||
using DataFlow: OSet
|
||||
|
||||
params(ps, p::TrackedArray) = push!(ps, p)
|
||||
|
||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||
|
||||
function params(m)
|
||||
ps = OSet()
|
||||
params(ps, m)
|
||||
return collect(ps)
|
||||
end
|
Loading…
Reference in New Issue
Block a user