internal rename

parent b951377426
commit cabb81e30b
@@ -25,16 +25,16 @@ loss(x, y) # ~ 3
 Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to CUDA arrays. Taking derivatives and training works exactly as before.
 
-If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
+If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
 
 ```julia
 d = Dense(10, 5, σ)
-d = mapleaves(cu, d)
+d = fmap(cu, d)
 d.W # Tracked CuArray
 d(cu(rand(10))) # CuArray output
 
 m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-m = mapleaves(cu, m)
+m = fmap(cu, m)
 m(cu(rand(10)))
 ```
 
 
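For orientation, here is a minimal sketch of the manual conversion the surrounding docs describe, using the variable names from the docs' preceding example (`CuArrays` was Flux's GPU backend at the time of this commit; this sketch is not part of the diff):

```julia
using Flux, CuArrays

W = cu(rand(2, 5))                 # parameters on the GPU
b = cu(rand(2))
x, y = cu(rand(5)), cu(rand(2))    # data on the GPU

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
loss(x, y)                         # runs entirely on the GPU
```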
@@ -215,7 +215,7 @@ m(5) # => 26
 Flux provides a set of helpers for custom layers, which you can enable by calling
 
 ```julia
-Flux.@treelike Affine
+Flux.@functor Affine
 ```
 
 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
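A sketch of what the renamed macro buys you, assuming the `Affine` definition used earlier on the same docs page (not part of this commit):

```julia
using Flux

# The custom layer the docs define before this hunk.
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(randn(out, in), randn(out))

# Make the layer callable.
(m::Affine)(x) = m.W * x .+ m.b

Flux.@functor Affine   # post-rename spelling of Flux.@treelike

a = Affine(10, 5)
Flux.params(a)   # parameter collection now works
# gpu(a)         # and so does GPU movement, when a GPU is available
```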
@@ -11,7 +11,7 @@ export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, mapleaves, cpu, gpu, f32, f64
+       SkipConnection, params, fmap, cpu, gpu, f32, f64
 
 include("optimise/Optimise.jl")
 using .Optimise
@@ -16,7 +16,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T))
 end
 
 function functorm(T, fs = nothing)
-  fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
+  fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
   fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
   :(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
 end
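To see what the quoting line in `functorm` does, here is a hedged REPL-style illustration (not part of the commit):

```julia
fs = :((cell, init))      # the tuple expression `@functor Recur cell, init` passes in
fs.args                   # => Any[:cell, :init]
map(QuoteNode, fs.args)   # => Any[:(:cell), :(:init)]
# Splicing these into the generated call yields
# makefunctor(@__MODULE__, Recur, (:cell, :init)).
```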
@@ -61,8 +61,6 @@ macro treelike(args...)
 end
 mapleaves(f, x) = fmap(f, x)
 
-# function params
-
 function loadparams!(m, xs)
   for (p, x) in zip(params(m), xs)
     size(p) == size(x) ||
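`loadparams!`, unchanged here apart from the stub comment deleted above it, copies arrays into a model's parameters pairwise. A rough usage sketch, assuming two layers of identical shape:

```julia
src = Dense(10, 5)
dst = Dense(10, 5)

# Copy src's weights into dst; sizes are checked pair by pair.
Flux.loadparams!(dst, Flux.params(src))
```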
@@ -73,7 +71,7 @@ end
 
 # CPU/GPU movement conveniences
 
-cpu(m) = mapleaves(x -> adapt(Array, x), m)
+cpu(m) = fmap(x -> adapt(Array, x), m)
 
 const gpu_adaptor = if has_cuarrays()
   CuArrays.cu
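The fallback branch below means `gpu` degrades to a no-op when CuArrays is unavailable; a sketch of the resulting round trip (not part of the diff):

```julia
m = Dense(10, 5)
gm = gpu(m)   # CuArrays.cu applied to every leaf array, or `identity` with no GPU backend
cm = cpu(gm)  # adapt(Array, _) on every leaf brings the weights back to the CPU
```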
@@ -81,13 +79,13 @@ else
   identity
 end
 
-gpu(x) = mapleaves(gpu_adaptor, x)
+gpu(x) = fmap(gpu_adaptor, x)
 
 # Precision
 
 adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
 
-paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
+paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
 
 f32(m) = paramtype(Float32, m)
 f64(m) = paramtype(Float64, m)
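`f32` and `f64` are just `paramtype` at a fixed precision; a rough usage sketch:

```julia
m = Chain(Dense(10, 5), Dense(5, 2))
m32 = f32(m)   # every real-valued parameter array converted to Float32
m64 = f64(m)   # likewise to Float64
```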
@@ -91,7 +91,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
   return Dense(initW(out, in), initb(out), σ)
 end
 
-@treelike Dense
+@functor Dense
 
 function (a::Dense)(x::AbstractArray)
   W, b, σ = a.W, a.b, a.σ
@@ -130,7 +130,7 @@ end
 Diagonal(in::Integer; initα = ones, initβ = zeros) =
   Diagonal(initα(in), initβ(in))
 
-@treelike Diagonal
+@functor Diagonal
 
 function (a::Diagonal)(x)
   α, β = a.α, a.β
@@ -183,7 +183,7 @@ function Maxout(f, n_alts)
   return Maxout(over)
 end
 
-@treelike Maxout
+@functor Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
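The `mapreduce` line computes an element-wise max over the candidate layers' outputs; a small sketch using the `Maxout(f, n_alts)` constructor named in the hunk header:

```julia
mo = Maxout(() -> Dense(10, 5), 2)  # two independently initialised Dense layers
x = rand(10)
y = mo(x)   # y[i] is the max over the two layers' outputs at position i
```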
@@ -208,7 +208,7 @@ struct SkipConnection
   connection #user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@treelike SkipConnection
+@functor SkipConnection
 
 function (skip::SkipConnection)(input)
   #We apply the layers to the input and return the result of the application of the layers and the original input
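A sketch of the arbitrary-connection comment in this hunk, assuming the two-argument `SkipConnection(layers, connection)` constructor (the `layers` field name is an assumption; only `connection` is visible in the diff):

```julia
sc = SkipConnection(Dense(10, 10), (a, b) -> a + b)
x = rand(10)
sc(x)   # the layer output combined with the raw input, here Dense(...)(x) + x
```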
@@ -45,7 +45,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
   Conv(init(k..., ch...), zeros(ch[2]), σ,
        stride = stride, pad = pad, dilation = dilation)
 
-@treelike Conv
+@functor Conv
 
 function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
@@ -102,7 +102,7 @@ ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
   ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
                 stride = stride, pad = pad, dilation = dilation)
 
-@treelike ConvTranspose
+@functor ConvTranspose
 
 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   # Calculate size of "input", from ∇conv_data()'s perspective...
@@ -180,7 +180,7 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
   )
 end
 
-@treelike DepthwiseConv
+@functor DepthwiseConv
 
 function (c::DepthwiseConv)(x)
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@@ -244,7 +244,7 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
   CrossCor(init(k..., ch...), zeros(ch[2]), σ,
            stride = stride, pad = pad, dilation = dilation)
 
-@treelike CrossCor
+@functor CrossCor
 
 function crosscor(x, w, ddims::DenseConvDims)
   ddims = DenseConvDims(ddims, F=true)
@@ -82,7 +82,7 @@ end
 LayerNorm(h::Integer) =
   LayerNorm(Diagonal(h))
 
-@treelike LayerNorm
+@functor LayerNorm
 
 (a::LayerNorm)(x) = a.diag(normalise(x))
 
@@ -38,7 +38,7 @@ function (m::Recur)(xs...)
   return y
 end
 
-@treelike Recur cell, init
+@functor Recur cell, init
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 
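This is the one rename where fields are listed explicitly. A hedged reading: only `cell` and `init` become functor children, so `fmap`/`gpu` traverse the trainable parts while the mutable state field is left out and rebuilt from `init` on reset:

```julia
rnn = RNN(10, 5)          # a Recur wrapping an RNNCell
curnn = fmap(gpu, rnn)    # moves cell weights and init, per `@functor Recur cell, init`
Flux.reset!(curnn)        # state is re-derived from init rather than mapped
```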
@@ -80,7 +80,7 @@ end
 
 hidden(m::RNNCell) = m.h
 
-@treelike RNNCell
+@functor RNNCell
 
 function Base.show(io::IO, l::RNNCell)
   print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@@ -128,7 +128,7 @@ end
 
 hidden(m::LSTMCell) = (m.h, m.c)
 
-@treelike LSTMCell
+@functor LSTMCell
 
 Base.show(io::IO, l::LSTMCell) =
   print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@@ -169,7 +169,7 @@ end
 
 hidden(m::GRUCell) = m.h
 
-@treelike GRUCell
+@functor GRUCell
 
 Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
@@ -13,7 +13,7 @@ end
 @testset "RNN" begin
   @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
     rnn = R(10, 5)
-    curnn = mapleaves(gpu, rnn)
+    curnn = fmap(gpu, rnn)
 
     Flux.reset!(rnn)
     Flux.reset!(curnn)
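The pattern this test exercises, sketched outside the test harness: a CPU model and its `fmap(gpu, _)` copy should agree up to floating-point tolerance (a sketch, not part of the test file):

```julia
rnn = RNN(10, 5)
curnn = fmap(gpu, rnn)
Flux.reset!(rnn); Flux.reset!(curnn)

x = rand(10)
y = rnn(x)
cuy = curnn(gpu(x))
# collect(cuy) ≈ y   # expected to hold approximately
```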