update treelike

This commit is contained in:
Mike J Innes 2018-07-12 22:43:11 +01:00
parent d782b33701
commit 70718e7a64
7 changed files with 23 additions and 12 deletions

View File

@ -211,7 +211,7 @@ m(5) # => 26
Flux provides a set of helpers for custom layers, which you can enable by calling
```julia
Flux.treelike(Affine)
Flux.@treelike Affine
```
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).

View File

@ -4,7 +4,7 @@ module Flux
# Zero Flux Given
using Juno, Requires, Reexport, StatsBase
using MacroTools, Juno, Requires, Reexport, StatsBase
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv,

View File

@ -73,7 +73,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
return Dense(param(initW(out, in)), param(initb(out)), σ)
end
treelike(Dense)
@treelike Dense
function (a::Dense)(x)
W, b, σ = a.W, a.b, a.σ
@ -104,7 +104,7 @@ end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(param(initα(in)), param(initβ(in)))
treelike(Diagonal)
@treelike Diagonal
function (a::Diagonal)(x)
α, β = a.α, a.β

View File

@ -35,7 +35,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)
Flux.treelike(Conv)
@treelike Conv
function (c::Conv)(x)
# TODO: breaks gpu broadcast :(

View File

@ -58,7 +58,7 @@ end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))
treelike(LayerNorm)
@treelike LayerNorm
(a::LayerNorm)(x) = a.diag(normalise(x))

View File

@ -38,7 +38,7 @@ function (m::Recur)(xs...)
return y
end
treelike(Recur, (:cell, :init))
@treelike Recur cell, init
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
@ -94,7 +94,7 @@ end
hidden(m::RNNCell) = m.h
treelike(RNNCell)
@treelike RNNCell
function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@ -143,7 +143,7 @@ end
hidden(m::LSTMCell) = (m.h, m.c)
treelike(LSTMCell)
@treelike LSTMCell
Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@ -184,7 +184,7 @@ end
hidden(m::GRUCell) = m.h
treelike(GRUCell)
@treelike GRUCell
Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")

View File

@ -7,13 +7,24 @@ mapchildren(f, x) = x
children(x::Tuple) = x
mapchildren(f, x::Tuple) = map(f, x)
function treelike(T, fs = fieldnames(T))
@eval current_module() begin
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
end
end
function treelike(T, fs = fieldnames(T))
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
treelike(Base._current_module(), T, fs)
end
macro treelike(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
end
isleaf(x) = isempty(children(x))
function mapleaves(f, x; cache = IdDict())