param collection
This commit is contained in:
parent
1179269355
commit
0ce8c0cee4
@ -16,6 +16,9 @@ include("Tracker/Tracker.jl")
|
|||||||
using .Tracker
|
using .Tracker
|
||||||
export track, back!
|
export track, back!
|
||||||
|
|
||||||
|
include("optimise/Optimise.jl")
|
||||||
|
using .Optimise
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
|
||||||
include("compiler/Compiler.jl")
|
include("compiler/Compiler.jl")
|
||||||
|
@ -8,6 +8,8 @@ end
|
|||||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||||
@forward Chain.layers Base.start, Base.next, Base.done
|
@forward Chain.layers Base.start, Base.next, Base.done
|
||||||
|
|
||||||
|
Optimise.children(c::Chain) = c.layers
|
||||||
|
|
||||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
|
|
||||||
Compiler.graph(s::Chain) =
|
Compiler.graph(s::Chain) =
|
||||||
@ -32,6 +34,8 @@ end
|
|||||||
Linear(in::Integer, out::Integer, σ = identity; init = initn) =
|
Linear(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||||
Linear(σ, track(init(out, in)), track(init(out)))
|
Linear(σ, track(init(out, in)), track(init(out)))
|
||||||
|
|
||||||
|
Optimise.children(d::Linear) = (d.W, d.b)
|
||||||
|
|
||||||
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
|
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
|
||||||
|
|
||||||
function Base.show(io::IO, l::Linear)
|
function Base.show(io::IO, l::Linear)
|
||||||
|
5
src/optimise/Optimise.jl
Normal file
5
src/optimise/Optimise.jl
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
module Optimise
|
||||||
|
|
||||||
|
include("params.jl")
|
||||||
|
|
||||||
|
end
|
14
src/optimise/params.jl
Normal file
14
src/optimise/params.jl
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
children(x) = ()
|
||||||
|
|
||||||
|
using ..Tracker.TrackedArray
|
||||||
|
using DataFlow: OSet
|
||||||
|
|
||||||
|
params(ps, p::TrackedArray) = push!(ps, p)
|
||||||
|
|
||||||
|
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||||
|
|
||||||
|
function params(m)
|
||||||
|
ps = OSet()
|
||||||
|
params(ps, m)
|
||||||
|
return collect(ps)
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user