Merge branch 'master' into tb/cuarrays_dnn

commit b90b02872f by Mike Innes, 2019-09-27 14:58:32 +01:00
18 changed files with 195 additions and 171 deletions


@ -5,7 +5,7 @@ variables:
CI_IMAGE_TAG: 'cuda'
include:
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml'
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v4/common.yml'
.flux:
extends: .test
@ -13,25 +13,39 @@ include:
- julia -e 'using InteractiveUtils;
versioninfo()'
- mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325
- julia -e 'using Pkg;
Pkg.add("CuArrays");'
- julia --project -e 'using Pkg;
Pkg.instantiate();
Pkg.build();
Pkg.test(; coverage=true);'
test:v1.0:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.0'
only:
- staging
- trying
extends: .flux
variables:
CI_VERSION_TAG: 'v1.0'
test:v1.1:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.1'
test:v1.2:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.2'
test:v1.3:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.3'
test:v1.0:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.0'
test:dev:
extends: .flux
variables:
CI_VERSION_TAG: 'v1.1'
only:
- staging
- trying
CI_VERSION_TAG: 'dev'
allow_failure: true


@ -149,9 +149,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "03f8776fbdae28c20c0d1d2ae4e090cd1dfcd247"
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.0.0"
version = "1.0.1"
[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
@ -390,7 +390,7 @@ version = "0.8.3"
[[Zygote]]
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "ce6d7142d665b1e4c71c678fa7db4da3bbc6743f"
git-tree-sha1 = "38241b40ebd8748bcacad5e6c7ba3ab3cc7a15c9"
repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
@ -398,6 +398,8 @@ version = "0.3.4"
[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6"
git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0"
repo-rev = "master"
repo-url = "https://github.com/FluxML/ZygoteRules.jl.git"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.1.0"
version = "0.2.0"


@ -23,13 +23,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
CUDAapi = "1.1"
CuArrays = "1.2"
NNlib = "0.6"
Zygote = "0.3"
julia = "1.1"
julia = "1"
[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"


@ -25,16 +25,16 @@ loss(x, y) # ~ 3
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
```julia
d = Dense(10, 5, σ)
d = mapleaves(cu, d)
d = fmap(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = mapleaves(cu, m)
m = fmap(cu, m)
d(cu(rand(10)))
```
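To make "taking derivatives and training works exactly as before" concrete, here is a minimal sketch, assuming CuArrays is installed and a working GPU is available:

```julia
using Flux, CuArrays

m = fmap(cu, Chain(Dense(10, 5, σ), Dense(5, 2), softmax))
x, y = cu(rand(10)), cu(rand(2))

loss(x, y) = Flux.crossentropy(m(x), y)
gs = gradient(() -> loss(x, y), params(m))  # the gradients are CuArrays as well
```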


@ -215,7 +215,7 @@ m(5) # => 26
Flux provides a set of helpers for custom layers, which you can enable by calling
```julia
Flux.@treelike Affine
Flux.@functor Affine
```
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
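A minimal sketch of what this unlocks, assuming the `Affine` struct defined earlier on this page (with fields `W` and `b`):

```julia
Flux.@functor Affine   # register Affine's fields with Flux

a = Affine(rand(5, 10), rand(5))
params(a)              # Params collecting a.W and a.b
gpu(a)                 # a copy of the layer with both fields moved to the GPU, if one is available
```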


@ -6,12 +6,12 @@ using Base: tail
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, forward
using Zygote: Params, @adjoint, gradient, pullback
export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, mapleaves, cpu, gpu, f32, f64
SkipConnection, params, fmap, cpu, gpu, f32, f64
include("optimise/Optimise.jl")
using .Optimise
@ -35,7 +35,7 @@ end
include("utils.jl")
include("onehot.jl")
include("treelike.jl")
include("functor.jl")
include("layers/stateless.jl")
include("layers/basic.jl")

src/functor.jl (new file, 91 lines)

@ -0,0 +1,91 @@
import Adapt: adapt, adapt_storage
using Zygote: IdSet
functor(x) = (), _ -> x
functor(x::Tuple) = x, y -> y
functor(x::NamedTuple) = x, y -> y
functor(x::AbstractArray) = x, y -> y
functor(x::AbstractArray{<:Number}) = (), _ -> x
function makefunctor(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
end
end
function functorm(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
end
macro functor(args...)
functorm(args...)
end
isleaf(x) = functor(x)[1] === ()
function fmap1(f, x)
func, re = functor(x)
re(map(f, func))
end
function fmap(f, x; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
end
trainable(m) = functor(m)[1]
params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x)
function params!(p::Params, x, seen = IdSet())
x in seen && return
push!(seen, x)
for child in trainable(x)
params!(p, child, seen)
end
end
function params(m...)
ps = Params()
params!(ps, m)
return ps
end
# Deprecated stuff
macro treelike(args...)
functorm(args...)
end
mapleaves(f, x) = fmap(f, x)
function loadparams!(m, xs)
for (p, x) in zip(params(m), xs)
size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))")
copyto!(p, x)
end
end
# CPU/GPU movement conveniences
cpu(m) = fmap(x -> adapt(Array, x), m)
const gpu_adaptor = if has_cuarrays()
CuArrays.cu
else
identity
end
gpu(x) = fmap(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
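For orientation, a small usage sketch of the machinery above; `Scale` is a hypothetical one-field layer, not part of this diff:

```julia
using Flux

struct Scale; w; end
Flux.@functor Scale         # defines functor(::Scale), so fmap/params can recurse into it

s = Scale(rand(Float32, 3))
Flux.isleaf(s)              # false: functor(s)[1] == (w = s.w,)
s2 = fmap(x -> 2 .* x, s)   # rebuilds Scale with every leaf array transformed
params(s)                   # Params containing s.w, via trainable (which defaults to functor(s)[1])
```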


@ -24,8 +24,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
functor(c::Chain) = c.layers, ls -> Chain(ls...)
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
@ -92,7 +91,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
return Dense(initW(out, in), initb(out), σ)
end
@treelike Dense
@functor Dense
function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
@ -131,7 +130,7 @@ end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(initα(in), initβ(in))
@treelike Diagonal
@functor Diagonal
function (a::Diagonal)(x)
α, β = a.α, a.β
@ -184,24 +183,29 @@ function Maxout(f, n_alts)
return Maxout(over)
end
@treelike Maxout
@functor Maxout
function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
"""
SkipConnection(layers...)
SkipConnection(layers, connection)
Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
and a shortcut connection linking the input to the block to the
output through a user-supplied callable.
Creates a skip connection, consisting of a layer or `Chain` of consecutive layers
plus a shortcut connection. The `connection` function combines the output of the layers
with the original input to give the final output.
`SkipConnection` requires the output dimension to be the same as the input.
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
and requires the output of the layers to be the same shape as the input.
Here is a more complicated example:
```
m = Conv((3,3), 4=>7, pad=(1,1))
x = ones(5,5,4,10);
size(m(x)) == (5, 5, 7, 10)
A 'ResNet'-type skip-connection with identity shortcut would simply be
```julia
SkipConnection(layer, (a,b) -> a + b)
sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
size(sm(x)) == (5, 5, 11, 10)
```
"""
struct SkipConnection
@ -209,15 +213,12 @@ struct SkipConnection
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
end
@treelike SkipConnection
@functor SkipConnection
function (skip::SkipConnection)(input)
# Apply the layers to the input, then combine the result with the original input via `connection`.
skip.connection(skip.layers(input), input)
end
function Base.show(io::IO, b::SkipConnection)
print(io, "SkipConnection(")
join(io, b.layers, ", ")
print(io, ")")
print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
end


@ -45,7 +45,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
Conv(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike Conv
@functor Conv
function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
@ -102,7 +102,7 @@ ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike ConvTranspose
@functor ConvTranspose
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
@ -180,7 +180,7 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
)
end
@treelike DepthwiseConv
@functor DepthwiseConv
function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@ -244,7 +244,7 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
CrossCor(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike CrossCor
@functor CrossCor
function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)


@ -82,7 +82,7 @@ end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))
@treelike LayerNorm
@functor LayerNorm
(a::LayerNorm)(x) = a.diag(normalise(x))
@ -134,6 +134,8 @@ BatchNorm(chs::Integer, λ = identity;
BatchNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)
trainable(bn::BatchNorm) = (bn.β, bn.γ)
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
@ -166,11 +168,7 @@ function (BN::BatchNorm)(x)
end
end
children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
@functor BatchNorm
function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(join(size(l.β), ", "))")
@ -224,6 +222,8 @@ InstanceNorm(chs::Integer, λ = identity;
InstanceNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)
trainable(in::InstanceNorm) = (in.β, in.γ)
function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) ||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
@ -261,11 +261,7 @@ function (in::InstanceNorm)(x)
end
end
children(in::InstanceNorm) =
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
@functor InstanceNorm
function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))")
@ -311,6 +307,8 @@ GroupNorm(chs::Integer, G::Integer, λ = identity;
GroupNorm(G, λ, initβ(chs), initγ(chs),
zeros(G,1), ones(G,1), ϵ, momentum)
trainable(gn::GroupNorm) = (gn.β, gn.γ)
function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
@ -360,11 +358,7 @@ function(gn::GroupNorm)(x)
end
end
children(gn::GroupNorm) =
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum)
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum)
@functor GroupNorm
function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")


@ -38,7 +38,7 @@ function (m::Recur)(xs...)
return y
end
@treelike Recur cell, init
@functor Recur cell, init
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
@ -52,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
"""
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
reset!(m::Recur) = (m.state = m.init)
reset!(m) = foreach(reset!, functor(m)[1])
flip(f, xs) = reverse(f.(reverse(xs)))
@ -79,7 +80,7 @@ end
hidden(m::RNNCell) = m.h
@treelike RNNCell
@functor RNNCell
function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@ -127,7 +128,7 @@ end
hidden(m::LSTMCell) = (m.h, m.c)
@treelike LSTMCell
@functor LSTMCell
Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@ -168,7 +169,7 @@ end
hidden(m::GRUCell) = m.h
@treelike GRUCell
@functor GRUCell
Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")


@ -1,86 +0,0 @@
import Adapt: adapt, adapt_storage
import Zygote: IdSet
children(x) = ()
mapchildren(f, x) = x
children(x::Tuple) = x
children(x::NamedTuple) = x
mapchildren(f, x::Tuple) = map(f, x)
mapchildren(f, x::NamedTuple) = map(f, x)
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
end
end
macro treelike(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
end
isleaf(x) = isempty(children(x))
function mapleaves(f, x; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end
function prefor(f, x; seen = IdSet())
x ∈ seen && return
push!(seen, x)
f(x)
foreach(x -> prefor(f, x, seen = seen), children(x))
return
end
function params(m)
ps = Params()
prefor(p ->
p isa AbstractArray{<:Real} &&
!any(p′ -> p′ === p, ps) && push!(ps, p),
m)
return ps
end
params(m...) = params(m)
function loadparams!(m, xs)
for (p, x) in zip(params(m), xs)
size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))")
copyto!(p, x)
end
end
# CPU/GPU movement conveniences
cpu(m) = mapleaves(x -> adapt(Array, x), m)
const gpu_adaptor = if has_cuarrays()
CuArrays.cu
else
identity
end
gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
# General parameter map
function mapparams(f, m)
mapleaves(m) do x
x isa Union{AbstractArray,Number} ? f(x) : x
end
end


@ -1,4 +1,5 @@
using Flux, CuArrays, Test
using Flux, Test
using Flux.CuArrays
using Flux: gpu
@info "Testing GPU Support"


@ -1,5 +1,5 @@
using Flux, CuArrays, Test
using Flux: forward
using Flux: pullback
@testset "CUDNN BatchNorm" begin
@testset "4D Input" begin
@ -8,8 +8,8 @@ using Flux: forward
cx = gpu(x)
cm = gpu(m)
y, back = forward((m, x) -> m(x), m, x)
cy, cback = forward((m, x) -> m(x), cm, cx)
y, back = pullback((m, x) -> m(x), m, x)
cy, cback = pullback((m, x) -> m(x), cm, cx)
@test cpu(cy) ≈ y
@ -28,8 +28,8 @@ using Flux: forward
cx = gpu(x)
cm = gpu(m)
y, back = forward((m, x) -> m(x), m, x)
cy, cback = forward((m, x) -> m(x), cm, cx)
y, back = pullback((m, x) -> m(x), m, x)
cy, cback = pullback((m, x) -> m(x), cm, cx)
@test cpu(cy) ≈ y


@ -1,5 +1,5 @@
using Flux, CuArrays, Test
using Flux: forward
using Flux: pullback
@testset for R in [RNN, GRU, LSTM]
m = R(10, 5) |> gpu
@ -13,7 +13,7 @@ end
@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
rnn = R(10, 5)
curnn = mapleaves(gpu, rnn)
curnn = fmap(gpu, rnn)
Flux.reset!(rnn)
Flux.reset!(curnn)
@ -22,8 +22,8 @@ end
rand(10, batch_size)
cux = gpu(x)
y, back = forward((r, x) -> r(x), rnn, x)
cuy, cuback = forward((r, x) -> r(x), curnn, cux)
y, back = pullback((r, x) -> r(x), rnn, x)
cuy, cuback = pullback((r, x) -> r(x), curnn, cux)
@test y ≈ collect(cuy)
@test haskey(Flux.CUDA.descs, curnn.cell)


@ -1,7 +1,7 @@
using Flux, Test, Statistics
using Zygote: forward
using Zygote: pullback
trainmode(f, x...) = forward(f, x...)[1]
trainmode(f, x...) = pullback(f, x...)[1]
trainmode(f) = (x...) -> trainmode(f, x...)
@testset "Dropout" begin
@ -42,6 +42,8 @@ end
let m = BatchNorm(2), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
@test length(params(m)) == 2
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
# initial m.σ is 1
@ -109,7 +111,9 @@ end
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)
x = reshape(collect(1:prod(sizes)), sizes)
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
@ -192,7 +196,9 @@ end
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
let m = GroupNorm(4,2), sizes = (3,4,2),
x = reshape(collect(1:prod(sizes)), sizes)
x = reshape(collect(1:prod(sizes)), sizes)
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32)


@ -55,7 +55,7 @@ const ϵ = 1e-7
y = rand(T, 2)
ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy)
fwd, back = Flux.forward(f, ŷ, y)
fwd, back = Flux.pullback(f, ŷ, y)
@test fwd isa T
@test eltype(back(one(T))[1]) == T
end
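These tests follow Zygote's rename of `forward` to `pullback`; for reference, a minimal sketch of that interface:

```julia
using Zygote

y, back = Zygote.pullback(x -> sum(abs2, x), [1.0, 2.0, 3.0])
y           # 14.0, the forward value
back(1.0)   # a 1-tuple of gradients: ([2.0, 4.0, 6.0],)
```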


@ -83,7 +83,6 @@ end
# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
Flux.children(a::Vector{Any}) = Tuple(a)
r[1] = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
end