From 4bafa2b374a987db56b9055319f8ded8143b836c Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 27 Sep 2017 21:11:21 +0100 Subject: [PATCH] generic tree functions --- src/Flux.jl | 1 + src/layers/basic.jl | 5 +++-- src/layers/recurrent.jl | 9 ++++----- src/optimise/Optimise.jl | 10 +++++++--- src/optimise/params.jl | 18 ------------------ src/tree.jl | 20 ++++++++++++++++++++ 6 files changed, 35 insertions(+), 28 deletions(-) delete mode 100644 src/optimise/params.jl create mode 100644 src/tree.jl diff --git a/src/Flux.jl b/src/Flux.jl index 8c88d229..ba9a6327 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,6 +21,7 @@ using .Optimise include("utils.jl") include("onehot.jl") +include("tree.jl") include("layers/stateless.jl") include("layers/basic.jl") diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 71de15fe..37c1b787 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -22,7 +22,8 @@ end @forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! @forward Chain.layers Base.start, Base.next, Base.done -Optimise.children(c::Chain) = c.layers +children(c::Chain) = c.layers +mapchildren(f, c::Chain) = Chain(f.(c.layers)...) (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) @@ -53,7 +54,7 @@ end Dense(in::Integer, out::Integer, σ = identity; init = initn) = Dense(σ, param(init(out, in)), param(init(out))) -Optimise.children(d::Dense) = (d.W, d.b) +treelike(Dense) (a::Dense)(x) = a.σ.(a.W*x .+ a.b) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 5d44e1bf..491209a0 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -16,7 +16,7 @@ function (m::Recur)(xs...) return y end -Optimise.children(m::Recur) = (m.cell,) +treelike(Recur) Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") @@ -24,7 +24,7 @@ _truncate(x::AbstractArray) = x _truncate(x::TrackedArray) = x.data _truncate(x::Tuple) = _truncate.(x) -truncate!(m) = foreach(truncate!, Optimise.children(m)) +truncate!(m) = foreach(truncate!, children(m)) truncate!(m::Recur) = (m.state = _truncate(m.state)) # Vanilla RNN @@ -44,7 +44,7 @@ end hidden(m::RNNCell) = m.h -Optimise.children(m::RNNCell) = (m.d, m.h) +treelike(RNNCell) function Base.show(io::IO, m::RNNCell) print(io, "RNNCell(", m.d, ")") @@ -82,8 +82,7 @@ end hidden(m::LSTMCell) = (m.h, m.c) -Optimise.children(m::LSTMCell) = - (m.forget, m.input, m.output, m.cell, m.h, m.c) +treelike(LSTMCell) Base.show(io::IO, m::LSTMCell) = print(io, "LSTMCell(", diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 57c202eb..57956426 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -3,15 +3,19 @@ module Optimise export update!, params, train!, SGD -include("params.jl") +struct Param{T} + x::T + Δ::T +end + +Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x)) + include("optimisers.jl") include("interface.jl") include("train.jl") using Flux.Tracker: TrackedArray -params(ps, p::TrackedArray) = push!(ps, p) - Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[]) end diff --git a/src/optimise/params.jl b/src/optimise/params.jl deleted file mode 100644 index c5163dbe..00000000 --- a/src/optimise/params.jl +++ /dev/null @@ -1,18 +0,0 @@ -using DataFlow: OSet - -children(x) = () - -params(ps, m) = foreach(m -> params(ps, m), children(m)) - -function params(m) - ps = OSet() - params(ps, m) - return collect(ps) -end - -struct Param{T} - x::T - Δ::T -end - -convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x)) diff --git a/src/tree.jl b/src/tree.jl new file mode 100644 index 00000000..438685d5 --- /dev/null +++ b/src/tree.jl @@ -0,0 +1,20 @@ +children(x) = () +mapchildren(f, x) = x + +function treelike(T, fs = fieldnames(T)) + @eval begin + children(x::$T) = ($([:(x.$f) for f in fs]...),) + mapchildren(f, x::$T) = $T(f.(children(x))...) + end +end + +using DataFlow: OSet + +params(ps, p::AbstractArray) = push!(ps, p) +params(ps, m) = foreach(m -> params(ps, m), children(m)) + +function params(m) + ps = OSet() + params(ps, m) + return collect(ps) +end