update treelike
This commit is contained in:
parent
d782b33701
commit
70718e7a64
|
@ -211,7 +211,7 @@ m(5) # => 26
|
|||
Flux provides a set of helpers for custom layers, which you can enable by calling
|
||||
|
||||
```julia
|
||||
Flux.treelike(Affine)
|
||||
Flux.@treelike Affine
|
||||
```
|
||||
|
||||
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
|
||||
|
|
|
@ -4,7 +4,7 @@ module Flux
|
|||
|
||||
# Zero Flux Given
|
||||
|
||||
using Juno, Requires, Reexport, StatsBase
|
||||
using MacroTools, Juno, Requires, Reexport, StatsBase
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
||||
|
|
|
@ -73,7 +73,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
|
|||
return Dense(param(initW(out, in)), param(initb(out)), σ)
|
||||
end
|
||||
|
||||
treelike(Dense)
|
||||
@treelike Dense
|
||||
|
||||
function (a::Dense)(x)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
|
@ -104,7 +104,7 @@ end
|
|||
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
||||
Diagonal(param(initα(in)), param(initβ(in)))
|
||||
|
||||
treelike(Diagonal)
|
||||
@treelike Diagonal
|
||||
|
||||
function (a::Diagonal)(x)
|
||||
α, β = a.α, a.β
|
||||
|
|
|
@ -35,7 +35,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
|
|||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
Flux.treelike(Conv)
|
||||
@treelike Conv
|
||||
|
||||
function (c::Conv)(x)
|
||||
# TODO: breaks gpu broadcast :(
|
||||
|
|
|
@ -58,7 +58,7 @@ end
|
|||
LayerNorm(h::Integer) =
|
||||
LayerNorm(Diagonal(h))
|
||||
|
||||
treelike(LayerNorm)
|
||||
@treelike LayerNorm
|
||||
|
||||
(a::LayerNorm)(x) = a.diag(normalise(x))
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ function (m::Recur)(xs...)
|
|||
return y
|
||||
end
|
||||
|
||||
treelike(Recur, (:cell, :init))
|
||||
@treelike Recur cell, init
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
|
@ -94,7 +94,7 @@ end
|
|||
|
||||
hidden(m::RNNCell) = m.h
|
||||
|
||||
treelike(RNNCell)
|
||||
@treelike RNNCell
|
||||
|
||||
function Base.show(io::IO, l::RNNCell)
|
||||
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
|
||||
|
@ -143,7 +143,7 @@ end
|
|||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
treelike(LSTMCell)
|
||||
@treelike LSTMCell
|
||||
|
||||
Base.show(io::IO, l::LSTMCell) =
|
||||
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
|
||||
|
@ -184,7 +184,7 @@ end
|
|||
|
||||
hidden(m::GRUCell) = m.h
|
||||
|
||||
treelike(GRUCell)
|
||||
@treelike GRUCell
|
||||
|
||||
Base.show(io::IO, l::GRUCell) =
|
||||
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
|
||||
|
|
|
@ -7,13 +7,24 @@ mapchildren(f, x) = x
|
|||
children(x::Tuple) = x
|
||||
mapchildren(f, x::Tuple) = map(f, x)
|
||||
|
||||
function treelike(T, fs = fieldnames(T))
|
||||
@eval current_module() begin
|
||||
function treelike(m::Module, T, fs = fieldnames(T))
|
||||
@eval m begin
|
||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
||||
end
|
||||
end
|
||||
|
||||
function treelike(T, fs = fieldnames(T))
|
||||
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
|
||||
treelike(Base._current_module(), T, fs)
|
||||
end
|
||||
|
||||
macro treelike(T, fs = nothing)
|
||||
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
|
||||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
|
||||
end
|
||||
|
||||
isleaf(x) = isempty(children(x))
|
||||
|
||||
function mapleaves(f, x; cache = IdDict())
|
||||
|
|
Loading…
Reference in New Issue