generic tree functions

This commit is contained in:
Mike J Innes 2017-09-27 21:11:21 +01:00
parent 2ec8401d2c
commit 4bafa2b374
6 changed files with 35 additions and 28 deletions

View File

@ -21,6 +21,7 @@ using .Optimise
include("utils.jl") include("utils.jl")
include("onehot.jl") include("onehot.jl")
include("tree.jl")
include("layers/stateless.jl") include("layers/stateless.jl")
include("layers/basic.jl") include("layers/basic.jl")

View File

@ -22,7 +22,8 @@ end
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! @forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
@forward Chain.layers Base.start, Base.next, Base.done @forward Chain.layers Base.start, Base.next, Base.done
Optimise.children(c::Chain) = c.layers children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
@ -53,7 +54,7 @@ end
Dense(in::Integer, out::Integer, σ = identity; init = initn) = Dense(in::Integer, out::Integer, σ = identity; init = initn) =
Dense(σ, param(init(out, in)), param(init(out))) Dense(σ, param(init(out, in)), param(init(out)))
Optimise.children(d::Dense) = (d.W, d.b) treelike(Dense)
(a::Dense)(x) = a.σ.(a.W*x .+ a.b) (a::Dense)(x) = a.σ.(a.W*x .+ a.b)

View File

@ -16,7 +16,7 @@ function (m::Recur)(xs...)
return y return y
end end
Optimise.children(m::Recur) = (m.cell,) treelike(Recur)
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
@ -24,7 +24,7 @@ _truncate(x::AbstractArray) = x
_truncate(x::TrackedArray) = x.data _truncate(x::TrackedArray) = x.data
_truncate(x::Tuple) = _truncate.(x) _truncate(x::Tuple) = _truncate.(x)
truncate!(m) = foreach(truncate!, Optimise.children(m)) truncate!(m) = foreach(truncate!, children(m))
truncate!(m::Recur) = (m.state = _truncate(m.state)) truncate!(m::Recur) = (m.state = _truncate(m.state))
# Vanilla RNN # Vanilla RNN
@ -44,7 +44,7 @@ end
hidden(m::RNNCell) = m.h hidden(m::RNNCell) = m.h
Optimise.children(m::RNNCell) = (m.d, m.h) treelike(RNNCell)
function Base.show(io::IO, m::RNNCell) function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")") print(io, "RNNCell(", m.d, ")")
@ -82,8 +82,7 @@ end
hidden(m::LSTMCell) = (m.h, m.c) hidden(m::LSTMCell) = (m.h, m.c)
Optimise.children(m::LSTMCell) = treelike(LSTMCell)
(m.forget, m.input, m.output, m.cell, m.h, m.c)
Base.show(io::IO, m::LSTMCell) = Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(", print(io, "LSTMCell(",

View File

@ -3,15 +3,19 @@ module Optimise
export update!, params, train!, export update!, params, train!,
SGD SGD
include("params.jl") struct Param{T}
x::T
Δ::T
end
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
include("optimisers.jl") include("optimisers.jl")
include("interface.jl") include("interface.jl")
include("train.jl") include("train.jl")
using Flux.Tracker: TrackedArray using Flux.Tracker: TrackedArray
params(ps, p::TrackedArray) = push!(ps, p)
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[]) Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
end end

View File

@ -1,18 +0,0 @@
using DataFlow: OSet
children(x) = ()
params(ps, m) = foreach(m -> params(ps, m), children(m))
function params(m)
ps = OSet()
params(ps, m)
return collect(ps)
end
struct Param{T}
x::T
Δ::T
end
convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))

20
src/tree.jl Normal file
View File

@ -0,0 +1,20 @@
children(x) = ()
mapchildren(f, x) = x
function treelike(T, fs = fieldnames(T))
@eval begin
children(x::$T) = ($([:(x.$f) for f in fs]...),)
mapchildren(f, x::$T) = $T(f.(children(x))...)
end
end
using DataFlow: OSet
params(ps, p::AbstractArray) = push!(ps, p)
params(ps, m) = foreach(m -> params(ps, m), children(m))
function params(m)
ps = OSet()
params(ps, m)
return collect(ps)
end