Merge branch 'master' into tb/cuarrays_dnn

commit b90b02872f
Author: Mike Innes
Date:   2019-09-27 14:58:32 +01:00
18 changed files with 195 additions and 171 deletions

.gitlab-ci.yml

@@ -5,7 +5,7 @@ variables:
   CI_IMAGE_TAG: 'cuda'
 
 include:
-  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml'
+  - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v4/common.yml'
 
 .flux:
   extends: .test
@@ -13,25 +13,39 @@ include:
     - julia -e 'using InteractiveUtils;
                 versioninfo()'
     - mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
-    - julia -e 'using Pkg;
-                Pkg.add("CuArrays");'
     - julia --project -e 'using Pkg;
                           Pkg.instantiate();
                           Pkg.build();
                           Pkg.test(; coverage=true);'
 
 test:v1.0:
   extends: .flux
   variables:
     CI_VERSION_TAG: 'v1.0'
-  only:
-    - staging
-    - trying
 
 test:v1.1:
+  extends: .flux
+  variables:
+    CI_VERSION_TAG: 'v1.1'
+
+test:v1.2:
+  extends: .flux
+  variables:
+    CI_VERSION_TAG: 'v1.2'
+
+test:v1.3:
+  extends: .flux
+  variables:
+    CI_VERSION_TAG: 'v1.3'
+
+test:v1.0:
+  extends: .flux
+  variables:
+    CI_VERSION_TAG: 'v1.0'
+
+test:dev:
   extends: .flux
   variables:
-    CI_VERSION_TAG: 'v1.1'
-  only:
-    - staging
-    - trying
+    CI_VERSION_TAG: 'dev'
+
+  allow_failure: true

Manifest.toml

@@ -149,9 +149,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 [[FFTW]]
 deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
-git-tree-sha1 = "03f8776fbdae28c20c0d1d2ae4e090cd1dfcd247"
+git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
 uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
-version = "1.0.0"
+version = "1.0.1"
 
 [[FillArrays]]
 deps = ["LinearAlgebra", "Random", "SparseArrays"]
@@ -390,7 +390,7 @@ version = "0.8.3"
 [[Zygote]]
 deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "ce6d7142d665b1e4c71c678fa7db4da3bbc6743f"
+git-tree-sha1 = "38241b40ebd8748bcacad5e6c7ba3ab3cc7a15c9"
 repo-rev = "master"
 repo-url = "https://github.com/FluxML/Zygote.jl.git"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -398,6 +398,8 @@ version = "0.3.4"
 [[ZygoteRules]]
 deps = ["MacroTools"]
-git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6"
+git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/ZygoteRules.jl.git"
 uuid = "700de1a5-db45-46bc-99cf-38207098b444"
-version = "0.1.0"
+version = "0.2.0"

Project.toml

@@ -23,13 +23,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
 CUDAapi = "1.1"
 CuArrays = "1.2"
 NNlib = "0.6"
 Zygote = "0.3"
-julia = "1.1"
+julia = "1"
 
 [extras]
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

docs/src/gpu.md

@@ -25,16 +25,16 @@ loss(x, y) # ~ 3
 Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
 
-If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
+If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
 
 ```julia
 d = Dense(10, 5, σ)
-d = mapleaves(cu, d)
+d = fmap(cu, d)
 d.W # Tracked CuArray
 d(cu(rand(10))) # CuArray output
 
 m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-m = mapleaves(cu, m)
+m = fmap(cu, m)
 d(cu(rand(10)))
 ```

docs/src/models/basics.md

@@ -215,7 +215,7 @@ m(5) # => 26
 Flux provides a set of helpers for custom layers, which you can enable by calling
 
 ```julia
-Flux.@treelike Affine
+Flux.@functor Affine
 ```
 
 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
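Aside (not part of the commit): a minimal sketch of what the renamed macro provides, using the `Affine` layer that this documentation page defines earlier. The layer definition is repeated here so the snippet is self-contained, and it assumes Flux at the state of this commit (`@functor` in place of `@treelike`).

```julia
using Flux

# The docs' running example layer, repeated for context.
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(randn(out, in), randn(out))

(m::Affine)(x) = m.W * x .+ m.b

# Opt in to Flux's layer helpers (previously `Flux.@treelike Affine`).
Flux.@functor Affine

a = Affine(10, 5)
params(a)   # Params collecting a.W and a.b
Flux.f32(a) # same layer with Float32 parameters, rebuilt via fmap
gpu(a)      # moves the fields to the GPU when CuArrays is functional
```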

src/Flux.jl

@@ -6,12 +6,12 @@ using Base: tail
 using Zygote, MacroTools, Juno, Reexport, Statistics, Random
 using MacroTools: @forward
 @reexport using NNlib
-using Zygote: Params, @adjoint, gradient, forward
+using Zygote: Params, @adjoint, gradient, pullback
 export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
-       SkipConnection, params, mapleaves, cpu, gpu, f32, f64
+       SkipConnection, params, fmap, cpu, gpu, f32, f64
 
 include("optimise/Optimise.jl")
 using .Optimise
@@ -35,7 +35,7 @@ end
 include("utils.jl")
 include("onehot.jl")
-include("treelike.jl")
+include("functor.jl")
 
 include("layers/stateless.jl")
 include("layers/basic.jl")

src/functor.jl (new file, 91 lines)

@@ -0,0 +1,91 @@
import Adapt: adapt, adapt_storage
using Zygote: IdSet

functor(x) = (), _ -> x

functor(x::Tuple) = x, y -> y
functor(x::NamedTuple) = x, y -> y

functor(x::AbstractArray) = x, y -> y
functor(x::AbstractArray{<:Number}) = (), _ -> x

function makefunctor(m::Module, T, fs = fieldnames(T))
  @eval m begin
    Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
  end
end

function functorm(T, fs = nothing)
  fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
  fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
  :(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
end

macro functor(args...)
  functorm(args...)
end

isleaf(x) = functor(x)[1] === ()

function fmap1(f, x)
  func, re = functor(x)
  re(map(f, func))
end

function fmap(f, x; cache = IdDict())
  haskey(cache, x) && return cache[x]
  cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
end

trainable(m) = functor(m)[1]

params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x)

function params!(p::Params, x, seen = IdSet())
  x in seen && return
  push!(seen, x)
  for child in trainable(x)
    params!(p, child, seen)
  end
end

function params(m...)
  ps = Params()
  params!(ps, m)
  return ps
end

# Deprecated stuff
macro treelike(args...)
  functorm(args...)
end
mapleaves(f, x) = fmap(f, x)

function loadparams!(m, xs)
  for (p, x) in zip(params(m), xs)
    size(p) == size(x) ||
      error("Expected param size $(size(p)), got $(size(x))")
    copyto!(p, x)
  end
end

# CPU/GPU movement conveniences

cpu(m) = fmap(x -> adapt(Array, x), m)

const gpu_adaptor = if has_cuarrays()
  CuArrays.cu
else
  identity
end

gpu(x) = fmap(gpu_adaptor, x)

# Precision

adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)

paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)

f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
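Aside (not part of the commit): a rough sketch of how the pieces above compose on a user-defined type. The `Scale` struct and its fields are hypothetical, chosen only to illustrate `@functor`, `trainable`, `params` and `fmap` as defined in this file.

```julia
using Flux

# A hypothetical container: one parameter array plus a plain number.
struct Scale
  s       # scale vector (an array, so it is collected by params)
  offset  # a scalar; fmap's array-only adaptors leave it untouched
end

(l::Scale)(x) = l.s .* x .+ l.offset

# Register the fields with Flux; equivalent to the deprecated `@treelike`.
Flux.@functor Scale

l = Scale(ones(3), 0.5)

Flux.isleaf(l)     # false: functor(l)[1] is a non-empty NamedTuple
Flux.trainable(l)  # (s = [1.0, 1.0, 1.0], offset = 0.5) by default
params(l)          # Params containing only the array field `s`
Flux.f32(l)        # a new Scale, rebuilt by fmap with Float32 storage
```

A custom `Flux.trainable(l::Scale) = (l.s,)` overload would narrow what `params` collects without changing how `fmap` rebuilds the struct, which is exactly the split the normalisation layers later in this diff rely on.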

src/layers/basic.jl

@@ -24,8 +24,7 @@
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex
 
-children(c::Chain) = c.layers
-mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
+functor(c::Chain) = c.layers, ls -> Chain(ls...)
 
 applychain(::Tuple{}, x) = x
 applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
@@ -92,7 +91,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
   return Dense(initW(out, in), initb(out), σ)
 end
 
-@treelike Dense
+@functor Dense
 
 function (a::Dense)(x::AbstractArray)
   W, b, σ = a.W, a.b, a.σ
@@ -131,7 +130,7 @@
 Diagonal(in::Integer; initα = ones, initβ = zeros) =
   Diagonal(initα(in), initβ(in))
 
-@treelike Diagonal
+@functor Diagonal
 
 function (a::Diagonal)(x)
   α, β = a.α, a.β
@@ -184,24 +183,29 @@ function Maxout(f, n_alts)
   return Maxout(over)
 end
 
-@treelike Maxout
+@functor Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
 end
 
 """
-    SkipConnection(layers...)
+    SkipConnection(layers, connection)
 
-Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
-and a shortcut connection linking the input to the block to the
-output through a user-supplied callable.
+Creates a Skip Connection, of a layer or `Chain` of consecutive layers
+plus a shortcut connection. The connection function will combine the result of the layers
+with the original input, to give the final output.
 
-`SkipConnection` requires the output dimension to be the same as the input.
+The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
+and requires the output of the layers to be the same shape as the input.
+Here is a more complicated example:
+```
+m = Conv((3,3), 4=>7, pad=(1,1))
+x = ones(5,5,4,10);
+size(m(x)) == (5, 5, 7, 10)
 
-A 'ResNet'-type skip-connection with identity shortcut would simply be
-```julia
-SkipConnection(layer, (a,b) -> a + b)
+sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
+size(sm(x)) == (5, 5, 11, 10)
 ```
 """
 struct SkipConnection
@@ -209,15 +213,12 @@ struct SkipConnection
   connection  #user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@treelike SkipConnection
+@functor SkipConnection
 
 function (skip::SkipConnection)(input)
-  #We apply the layers to the input and return the result of the application of the layers and the original input
   skip.connection(skip.layers(input), input)
 end
 
 function Base.show(io::IO, b::SkipConnection)
-  print(io, "SkipConnection(")
-  join(io, b.layers, ", ")
-  print(io, ")")
+  print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
 end

src/layers/conv.jl

@@ -45,7 +45,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
   Conv(init(k..., ch...), zeros(ch[2]), σ,
        stride = stride, pad = pad, dilation = dilation)
 
-@treelike Conv
+@functor Conv
 
 function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
@@ -102,7 +102,7 @@ ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
   ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
                 stride = stride, pad = pad, dilation = dilation)
 
-@treelike ConvTranspose
+@functor ConvTranspose
 
 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   # Calculate size of "input", from ∇conv_data()'s perspective...
@@ -180,7 +180,7 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
   )
 end
 
-@treelike DepthwiseConv
+@functor DepthwiseConv
 
 function (c::DepthwiseConv)(x)
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@@ -244,7 +244,7 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
   CrossCor(init(k..., ch...), zeros(ch[2]), σ,
            stride = stride, pad = pad, dilation = dilation)
 
-@treelike CrossCor
+@functor CrossCor
 
 function crosscor(x, w, ddims::DenseConvDims)
   ddims = DenseConvDims(ddims, F=true)

src/layers/normalise.jl

@@ -82,7 +82,7 @@ end
 LayerNorm(h::Integer) =
   LayerNorm(Diagonal(h))
 
-@treelike LayerNorm
+@functor LayerNorm
 
 (a::LayerNorm)(x) = a.diag(normalise(x))
@@ -134,6 +134,8 @@ BatchNorm(chs::Integer, λ = identity;
   BatchNorm(λ, initβ(chs), initγ(chs),
             zeros(chs), ones(chs), ϵ, momentum)
 
+trainable(bn::BatchNorm) = (bn.β, bn.γ)
+
 function (BN::BatchNorm)(x)
   size(x, ndims(x)-1) == length(BN.β) ||
     error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
@@ -166,11 +168,7 @@ function (BN::BatchNorm)(x)
   end
 end
 
-children(BN::BatchNorm) =
-  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
-
-mapchildren(f, BN::BatchNorm) =  # e.g. mapchildren(cu, BN)
-  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
+@functor BatchNorm
 
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
@@ -224,6 +222,8 @@ InstanceNorm(chs::Integer, λ = identity;
   InstanceNorm(λ, initβ(chs), initγ(chs),
                zeros(chs), ones(chs), ϵ, momentum)
 
+trainable(in::InstanceNorm) = (in.β, in.γ)
+
 function (in::InstanceNorm)(x)
   size(x, ndims(x)-1) == length(in.β) ||
     error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
@@ -261,11 +261,7 @@ function (in::InstanceNorm)(x)
   end
 end
 
-children(in::InstanceNorm) =
-  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
-
-mapchildren(f, in::InstanceNorm) =  # e.g. mapchildren(cu, in)
-  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
+@functor InstanceNorm
 
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
@@ -311,6 +307,8 @@ GroupNorm(chs::Integer, G::Integer, λ = identity;
   GroupNorm(G, λ, initβ(chs), initγ(chs),
             zeros(G,1), ones(G,1), ϵ, momentum)
 
+trainable(gn::GroupNorm) = (gn.β, gn.γ)
+
 function(gn::GroupNorm)(x)
   size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
   ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
@@ -360,11 +358,7 @@ function(gn::GroupNorm)(x)
   end
 end
 
-children(gn::GroupNorm) =
-  (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum)
-
-mapchildren(f, gn::GroupNorm) =  # e.g. mapchildren(cu, BN)
-  GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum)
+@functor GroupNorm
 
 function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
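Aside (not part of the commit): a small sketch of how the new `trainable` overloads interact with `params` and the `fmap`-based helpers, assuming Flux at the state of this commit.

```julia
using Flux

bn = BatchNorm(2)

# `trainable` narrows what the optimiser sees…
Flux.trainable(bn) == (bn.β, bn.γ)  # true
length(params(bn)) == 2             # only β and γ are collected

# …while `@functor BatchNorm` still exposes every field to structural maps,
# so fmap-based helpers such as f32/gpu also convert the running statistics.
bn32 = Flux.f32(bn)
eltype(bn32.μ) == Float32           # true: μ and σ² are converted as well
```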

src/layers/recurrent.jl

@@ -38,7 +38,7 @@ function (m::Recur)(xs...)
   return y
 end
 
-@treelike Recur cell, init
+@functor Recur cell, init
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
@@ -52,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
   rnn.state = hidden(rnn.cell)
 """
-reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
+reset!(m::Recur) = (m.state = m.init)
+reset!(m) = foreach(reset!, functor(m)[1])
 
 flip(f, xs) = reverse(f.(reverse(xs)))
@@ -79,7 +80,7 @@ end
 hidden(m::RNNCell) = m.h
 
-@treelike RNNCell
+@functor RNNCell
 
 function Base.show(io::IO, l::RNNCell)
   print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@@ -127,7 +128,7 @@ end
 hidden(m::LSTMCell) = (m.h, m.c)
 
-@treelike LSTMCell
+@functor LSTMCell
 
 Base.show(io::IO, l::LSTMCell) =
   print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@@ -168,7 +169,7 @@ end
 hidden(m::GRUCell) = m.h
 
-@treelike GRUCell
+@functor GRUCell
 
 Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
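Aside (not part of the commit): the redefined `reset!` no longer walks the model with `prefor`; it recurses through `functor(m)[1]` instead. A short sketch of the intended behaviour:

```julia
using Flux

m = Chain(LSTM(10, 5), Dense(5, 2))

# Running the chain over a sequence accumulates hidden state in the LSTM's Recur.
ys = [m(rand(Float32, 10)) for _ in 1:3]

# On a Recur, reset! restores state = init; on anything else it simply
# recurses into the children returned by functor(m)[1].
Flux.reset!(m)
```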

src/treelike.jl (deleted file, 86 lines)

@@ -1,86 +0,0 @@
import Adapt: adapt, adapt_storage
import Zygote: IdSet

children(x) = ()
mapchildren(f, x) = x

children(x::Tuple) = x
children(x::NamedTuple) = x
mapchildren(f, x::Tuple) = map(f, x)
mapchildren(f, x::NamedTuple) = map(f, x)

function treelike(m::Module, T, fs = fieldnames(T))
  @eval m begin
    Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
    Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
  end
end

macro treelike(T, fs = nothing)
  fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
  fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
  :(treelike(@__MODULE__, $(esc(T)), $(fs...)))
end

isleaf(x) = isempty(children(x))

function mapleaves(f, x; cache = IdDict())
  haskey(cache, x) && return cache[x]
  cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end

function prefor(f, x; seen = IdSet())
  x ∈ seen && return
  push!(seen, x)
  f(x)
  foreach(x -> prefor(f, x, seen = seen), children(x))
  return
end

function params(m)
  ps = Params()
  prefor(p ->
    p isa AbstractArray{<:Real} &&
    !any(p′ -> p′ === p, ps) && push!(ps, p),
    m)
  return ps
end

params(m...) = params(m)

function loadparams!(m, xs)
  for (p, x) in zip(params(m), xs)
    size(p) == size(x) ||
      error("Expected param size $(size(p)), got $(size(x))")
    copyto!(p, x)
  end
end

# CPU/GPU movement conveniences

cpu(m) = mapleaves(x -> adapt(Array, x), m)

const gpu_adaptor = if has_cuarrays()
  CuArrays.cu
else
  identity
end

gpu(x) = mapleaves(gpu_adaptor, x)

# Precision

adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)

paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)

f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)

# General parameter map

function mapparams(f, m)
  mapleaves(m) do x
    x isa Union{AbstractArray,Number} ? f(x) : x
  end
end

test/cuda/cuda.jl

@@ -1,4 +1,5 @@
-using Flux, CuArrays, Test
+using Flux, Test
+using Flux.CuArrays
 using Flux: gpu
 
 @info "Testing GPU Support"

test/cuda/cudnn.jl

@@ -1,5 +1,5 @@
 using Flux, CuArrays, Test
-using Flux: forward
+using Flux: pullback
 
 @testset "CUDNN BatchNorm" begin
   @testset "4D Input" begin
@@ -8,8 +8,8 @@ using Flux: forward
     cx = gpu(x)
     cm = gpu(m)
 
-    y, back = forward((m, x) -> m(x), m, x)
-    cy, cback = forward((m, x) -> m(x), cm, cx)
+    y, back = pullback((m, x) -> m(x), m, x)
+    cy, cback = pullback((m, x) -> m(x), cm, cx)
 
     @test cpu(cy) ≈ y
@@ -28,8 +28,8 @@ using Flux: forward
     cx = gpu(x)
     cm = gpu(m)
 
-    y, back = forward((m, x) -> m(x), m, x)
-    cy, cback = forward((m, x) -> m(x), cm, cx)
+    y, back = pullback((m, x) -> m(x), m, x)
+    cy, cback = pullback((m, x) -> m(x), cm, cx)
 
     @test cpu(cy) ≈ y

test/cuda/curnn.jl

@@ -1,5 +1,5 @@
 using Flux, CuArrays, Test
-using Flux: forward
+using Flux: pullback
 
 @testset for R in [RNN, GRU, LSTM]
   m = R(10, 5) |> gpu
@@ -13,7 +13,7 @@ end
 @testset "RNN" begin
   @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
     rnn = R(10, 5)
-    curnn = mapleaves(gpu, rnn)
+    curnn = fmap(gpu, rnn)
 
     Flux.reset!(rnn)
     Flux.reset!(curnn)
@@ -22,8 +22,8 @@ end
       rand(10, batch_size)
     cux = gpu(x)
 
-    y, back = forward((r, x) -> r(x), rnn, x)
-    cuy, cuback = forward((r, x) -> r(x), curnn, cux)
+    y, back = pullback((r, x) -> r(x), rnn, x)
+    cuy, cuback = pullback((r, x) -> r(x), curnn, cux)
 
     @test y ≈ collect(cuy)
     @test haskey(Flux.CUDA.descs, curnn.cell)

test/layers/normalisation.jl

@@ -1,7 +1,7 @@
 using Flux, Test, Statistics
-using Zygote: forward
+using Zygote: pullback
 
-trainmode(f, x...) = forward(f, x...)[1]
+trainmode(f, x...) = pullback(f, x...)[1]
 trainmode(f) = (x...) -> trainmode(f, x...)
 
 @testset "Dropout" begin
@@ -42,6 +42,8 @@ end
   let m = BatchNorm(2), x = [1.0 3.0 5.0;
                              2.0 4.0 6.0]
 
+    @test length(params(m)) == 2
+
     @test m.β == [0, 0]  # initβ(2)
     @test m.γ == [1, 1]  # initγ(2)
     # initial m.σ is 1
@@ -109,7 +111,9 @@ end
   expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
   # begin tests
   let m = InstanceNorm(2), sizes = (3, 2, 2),
       x = reshape(collect(1:prod(sizes)), sizes)
 
+    @test length(params(m)) == 2
+
     x = Float64.(x)
     @test m.β == [0, 0]  # initβ(2)
     @test m.γ == [1, 1]  # initγ(2)
@@ -192,7 +196,9 @@ end
   squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
   let m = GroupNorm(4,2), sizes = (3,4,2),
       x = reshape(collect(1:prod(sizes)), sizes)
 
+    @test length(params(m)) == 2
+
     x = Float64.(x)
     @test m.β == [0, 0, 0, 0]  # initβ(32)
     @test m.γ == [1, 1, 1, 1]  # initγ(32)

test/layers/stateless.jl

@@ -55,7 +55,7 @@ const ϵ = 1e-7
     y = rand(T, 2)
     ŷ = rand(T, 2)
     for f in (mse, crossentropy, logitcrossentropy)
-      fwd, back = Flux.forward(f, ŷ, y)
+      fwd, back = Flux.pullback(f, ŷ, y)
       @test fwd isa T
       @test eltype(back(one(T))[1]) == T
     end

test/utils.jl

@@ -83,7 +83,6 @@ end
   # Self-referential array. Just want params, no stack overflow pls.
   r = Any[nothing,m]
-  Flux.children(a::Vector{Any}) = Tuple(a)
   r[1] = r
 
   @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
 end