diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 05217e81..9af14c6a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -5,7 +5,7 @@ variables: CI_IMAGE_TAG: 'cuda' include: - - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v3/common.yml' + - 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v4/common.yml' .flux: extends: .test @@ -13,25 +13,39 @@ include: - julia -e 'using InteractiveUtils; versioninfo()' - mkdir $JULIA_DEPOT_PATH # Pkg3.jl#325 - - julia -e 'using Pkg; - Pkg.add("CuArrays");' - julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.build(); Pkg.test(; coverage=true);' test:v1.0: - extends: .flux - variables: - CI_VERSION_TAG: 'v1.0' - only: - - staging - - trying + extends: .flux + variables: + CI_VERSION_TAG: 'v1.0' test:v1.1: + extends: .flux + variables: + CI_VERSION_TAG: 'v1.1' + +test:v1.2: + extends: .flux + variables: + CI_VERSION_TAG: 'v1.2' + +test:v1.3: + extends: .flux + variables: + CI_VERSION_TAG: 'v1.3' + +test:v1.0: + extends: .flux + variables: + CI_VERSION_TAG: 'v1.0' + +test:dev: extends: .flux variables: - CI_VERSION_TAG: 'v1.1' - only: - - staging - - trying + CI_VERSION_TAG: 'dev' + + allow_failure: true diff --git a/Manifest.toml b/Manifest.toml index 4d825f17..87f5075f 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -149,9 +149,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[FFTW]] deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] -git-tree-sha1 = "03f8776fbdae28c20c0d1d2ae4e090cd1dfcd247" +git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.0.0" +version = "1.0.1" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] @@ -390,7 +390,7 @@ version = "0.8.3" [[Zygote]] deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "ce6d7142d665b1e4c71c678fa7db4da3bbc6743f" +git-tree-sha1 = "38241b40ebd8748bcacad5e6c7ba3ab3cc7a15c9" repo-rev = "master" repo-url = "https://github.com/FluxML/Zygote.jl.git" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -398,6 +398,8 @@ version = "0.3.4" [[ZygoteRules]] deps = ["MacroTools"] -git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6" +git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0" +repo-rev = "master" +repo-url = "https://github.com/FluxML/ZygoteRules.jl.git" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.1.0" +version = "0.2.0" diff --git a/Project.toml b/Project.toml index 7cd78984..5e357c59 100644 --- a/Project.toml +++ b/Project.toml @@ -23,13 +23,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] CUDAapi = "1.1" CuArrays = "1.2" NNlib = "0.6" Zygote = "0.3" -julia = "1.1" +julia = "1" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/src/gpu.md b/docs/src/gpu.md index aed33f4e..bb13fdd1 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -25,16 +25,16 @@ loss(x, y) # ~ 3 Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before. -If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once. +If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once. ```julia d = Dense(10, 5, σ) -d = mapleaves(cu, d) +d = fmap(cu, d) d.W # Tracked CuArray d(cu(rand(10))) # CuArray output m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) -m = mapleaves(cu, m) +m = fmap(cu, m) d(cu(rand(10))) ``` diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index ddd81992..d83fc462 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -215,7 +215,7 @@ m(5) # => 26 Flux provides a set of helpers for custom layers, which you can enable by calling ```julia -Flux.@treelike Affine +Flux.@functor Affine ``` This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md). diff --git a/src/Flux.jl b/src/Flux.jl index 9d1fbfc5..0b57f81d 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,12 +6,12 @@ using Base: tail using Zygote, MacroTools, Juno, Reexport, Statistics, Random using MacroTools: @forward @reexport using NNlib -using Zygote: Params, @adjoint, gradient, forward +using Zygote: Params, @adjoint, gradient, pullback export gradient export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, - SkipConnection, params, mapleaves, cpu, gpu, f32, f64 + SkipConnection, params, fmap, cpu, gpu, f32, f64 include("optimise/Optimise.jl") using .Optimise @@ -35,7 +35,7 @@ end include("utils.jl") include("onehot.jl") -include("treelike.jl") +include("functor.jl") include("layers/stateless.jl") include("layers/basic.jl") diff --git a/src/functor.jl b/src/functor.jl new file mode 100644 index 00000000..73483ab9 --- /dev/null +++ b/src/functor.jl @@ -0,0 +1,91 @@ +import Adapt: adapt, adapt_storage +using Zygote: IdSet + +functor(x) = (), _ -> x + +functor(x::Tuple) = x, y -> y +functor(x::NamedTuple) = x, y -> y + +functor(x::AbstractArray) = x, y -> y +functor(x::AbstractArray{<:Number}) = (), _ -> x + +function makefunctor(m::Module, T, fs = fieldnames(T)) + @eval m begin + Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...) + end +end + +function functorm(T, fs = nothing) + fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)") + fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] + :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) +end + +macro functor(args...) + functorm(args...) +end + +isleaf(x) = functor(x)[1] === () + +function fmap1(f, x) + func, re = functor(x) + re(map(f, func)) +end + +function fmap(f, x; cache = IdDict()) + haskey(cache, x) && return cache[x] + cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) +end + +trainable(m) = functor(m)[1] + +params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x) + +function params!(p::Params, x, seen = IdSet()) + x in seen && return + push!(seen, x) + for child in trainable(x) + params!(p, child, seen) + end +end + +function params(m...) + ps = Params() + params!(ps, m) + return ps +end + +# Deprecated stuff +macro treelike(args...) + functorm(args...) +end +mapleaves(f, x) = fmap(f, x) + +function loadparams!(m, xs) + for (p, x) in zip(params(m), xs) + size(p) == size(x) || + error("Expected param size $(size(p)), got $(size(x))") + copyto!(p, x) + end +end + +# CPU/GPU movement conveniences + +cpu(m) = fmap(x -> adapt(Array, x), m) + +const gpu_adaptor = if has_cuarrays() + CuArrays.cu +else + identity +end + +gpu(x) = fmap(gpu_adaptor, x) + +# Precision + +adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) + +paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m) + +f32(m) = paramtype(Float32, m) +f64(m) = paramtype(Float64, m) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0cebead1..f42a9619 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -24,8 +24,7 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex -children(c::Chain) = c.layers -mapchildren(f, c::Chain) = Chain(f.(c.layers)...) +functor(c::Chain) = c.layers, ls -> Chain(ls...) applychain(::Tuple{}, x) = x applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) @@ -92,7 +91,7 @@ function Dense(in::Integer, out::Integer, σ = identity; return Dense(initW(out, in), initb(out), σ) end -@treelike Dense +@functor Dense function (a::Dense)(x::AbstractArray) W, b, σ = a.W, a.b, a.σ @@ -131,7 +130,7 @@ end Diagonal(in::Integer; initα = ones, initβ = zeros) = Diagonal(initα(in), initβ(in)) -@treelike Diagonal +@functor Diagonal function (a::Diagonal)(x) α, β = a.α, a.β @@ -184,24 +183,29 @@ function Maxout(f, n_alts) return Maxout(over) end -@treelike Maxout +@functor Maxout function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end """ - SkipConnection(layers...) + SkipConnection(layers, connection) -Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers -and a shortcut connection linking the input to the block to the -output through a user-supplied callable. +Creates a Skip Connection, of a layer or `Chain` of consecutive layers +plus a shortcut connection. The connection function will combine the result of the layers +with the original input, to give the final output. -`SkipConnection` requires the output dimension to be the same as the input. +The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`, +and requires the output of the layers to be the same shape as the input. +Here is a more complicated example: +``` +m = Conv((3,3), 4=>7, pad=(1,1)) +x = ones(5,5,4,10); +size(m(x)) == (5, 5, 7, 10) -A 'ResNet'-type skip-connection with identity shortcut would simply be -```julia - SkipConnection(layer, (a,b) -> a + b) +sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3)) +size(sm(x)) == (5, 5, 11, 10) ``` """ struct SkipConnection @@ -209,15 +213,12 @@ struct SkipConnection connection #user can pass arbitrary connections here, such as (a,b) -> a + b end -@treelike SkipConnection +@functor SkipConnection function (skip::SkipConnection)(input) - #We apply the layers to the input and return the result of the application of the layers and the original input skip.connection(skip.layers(input), input) end function Base.show(io::IO, b::SkipConnection) - print(io, "SkipConnection(") - join(io, b.layers, ", ") - print(io, ")") + print(io, "SkipConnection(", b.layers, ", ", b.connection, ")") end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4361a389..519f129f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -45,7 +45,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; Conv(init(k..., ch...), zeros(ch[2]), σ, stride = stride, pad = pad, dilation = dilation) -@treelike Conv +@functor Conv function (c::Conv)(x::AbstractArray) # TODO: breaks gpu broadcast :( @@ -102,7 +102,7 @@ ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ, stride = stride, pad = pad, dilation = dilation) -@treelike ConvTranspose +@functor ConvTranspose function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) # Calculate size of "input", from ∇conv_data()'s perspective... @@ -180,7 +180,7 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = ) end -@treelike DepthwiseConv +@functor DepthwiseConv function (c::DepthwiseConv)(x) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) @@ -244,7 +244,7 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; CrossCor(init(k..., ch...), zeros(ch[2]), σ, stride = stride, pad = pad, dilation = dilation) -@treelike CrossCor +@functor CrossCor function crosscor(x, w, ddims::DenseConvDims) ddims = DenseConvDims(ddims, F=true) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 61a62adf..b421d3e7 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -82,7 +82,7 @@ end LayerNorm(h::Integer) = LayerNorm(Diagonal(h)) -@treelike LayerNorm +@functor LayerNorm (a::LayerNorm)(x) = a.diag(normalise(x)) @@ -134,6 +134,8 @@ BatchNorm(chs::Integer, λ = identity; BatchNorm(λ, initβ(chs), initγ(chs), zeros(chs), ones(chs), ϵ, momentum) +trainable(bn::BatchNorm) = (bn.β, bn.γ) + function (BN::BatchNorm)(x) size(x, ndims(x)-1) == length(BN.β) || error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") @@ -166,11 +168,7 @@ function (BN::BatchNorm)(x) end end -children(BN::BatchNorm) = - (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum) - -mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) - BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum) +@functor BatchNorm function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(join(size(l.β), ", "))") @@ -224,6 +222,8 @@ InstanceNorm(chs::Integer, λ = identity; InstanceNorm(λ, initβ(chs), initγ(chs), zeros(chs), ones(chs), ϵ, momentum) +trainable(in::InstanceNorm) = (in.β, in.γ) + function (in::InstanceNorm)(x) size(x, ndims(x)-1) == length(in.β) || error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))") @@ -261,11 +261,7 @@ function (in::InstanceNorm)(x) end end -children(in::InstanceNorm) = - (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum) - -mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in) - InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum) +@functor InstanceNorm function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(join(size(l.β), ", "))") @@ -311,6 +307,8 @@ GroupNorm(chs::Integer, G::Integer, λ = identity; GroupNorm(G, λ, initβ(chs), initγ(chs), zeros(G,1), ones(G,1), ϵ, momentum) +trainable(gn::GroupNorm) = (gn.β, gn.γ) + function(gn::GroupNorm)(x) size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels") ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work") @@ -360,11 +358,7 @@ function(gn::GroupNorm)(x) end end -children(gn::GroupNorm) = - (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum) - -mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN) - GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum) +@functor GroupNorm function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))") diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index b5eea4a4..f2344af8 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -38,7 +38,7 @@ function (m::Recur)(xs...) return y end -@treelike Recur cell, init +@functor Recur cell, init Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") @@ -52,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to rnn.state = hidden(rnn.cell) """ -reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m) +reset!(m::Recur) = (m.state = m.init) +reset!(m) = foreach(reset!, functor(m)[1]) flip(f, xs) = reverse(f.(reverse(xs))) @@ -79,7 +80,7 @@ end hidden(m::RNNCell) = m.h -@treelike RNNCell +@functor RNNCell function Base.show(io::IO, l::RNNCell) print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)) @@ -127,7 +128,7 @@ end hidden(m::LSTMCell) = (m.h, m.c) -@treelike LSTMCell +@functor LSTMCell Base.show(io::IO, l::LSTMCell) = print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")") @@ -168,7 +169,7 @@ end hidden(m::GRUCell) = m.h -@treelike GRUCell +@functor GRUCell Base.show(io::IO, l::GRUCell) = print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")") diff --git a/src/treelike.jl b/src/treelike.jl deleted file mode 100644 index 42b10f23..00000000 --- a/src/treelike.jl +++ /dev/null @@ -1,86 +0,0 @@ -import Adapt: adapt, adapt_storage -import Zygote: IdSet - -children(x) = () -mapchildren(f, x) = x - -children(x::Tuple) = x -children(x::NamedTuple) = x -mapchildren(f, x::Tuple) = map(f, x) -mapchildren(f, x::NamedTuple) = map(f, x) - -function treelike(m::Module, T, fs = fieldnames(T)) - @eval m begin - Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),) - Flux.mapchildren(f, x::$T) = $T(f.($children(x))...) - end -end - -macro treelike(T, fs = nothing) - fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)") - fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] - :(treelike(@__MODULE__, $(esc(T)), $(fs...))) -end - -isleaf(x) = isempty(children(x)) - -function mapleaves(f, x; cache = IdDict()) - haskey(cache, x) && return cache[x] - cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x) -end - -function prefor(f, x; seen = IdSet()) - x ∈ seen && return - push!(seen, x) - f(x) - foreach(x -> prefor(f, x, seen = seen), children(x)) - return -end - -function params(m) - ps = Params() - prefor(p -> - p isa AbstractArray{<:Real} && - !any(p′ -> p′ === p, ps) && push!(ps, p), - m) - return ps -end - -params(m...) = params(m) - -function loadparams!(m, xs) - for (p, x) in zip(params(m), xs) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") - copyto!(p, x) - end -end - -# CPU/GPU movement conveniences - -cpu(m) = mapleaves(x -> adapt(Array, x), m) - -const gpu_adaptor = if has_cuarrays() - CuArrays.cu -else - identity -end - -gpu(x) = mapleaves(gpu_adaptor, x) - -# Precision - -adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) - -paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m) - -f32(m) = paramtype(Float32, m) -f64(m) = paramtype(Float64, m) - -# General parameter map - -function mapparams(f, m) - mapleaves(m) do x - x isa Union{AbstractArray,Number} ? f(x) : x - end -end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 20399ef7..59bc7f50 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -1,4 +1,5 @@ -using Flux, CuArrays, Test +using Flux, Test +using Flux.CuArrays using Flux: gpu @info "Testing GPU Support" diff --git a/test/cuda/cudnn.jl b/test/cuda/cudnn.jl index a7fc244e..881e0b39 100644 --- a/test/cuda/cudnn.jl +++ b/test/cuda/cudnn.jl @@ -1,5 +1,5 @@ using Flux, CuArrays, Test -using Flux: forward +using Flux: pullback @testset "CUDNN BatchNorm" begin @testset "4D Input" begin @@ -8,8 +8,8 @@ using Flux: forward cx = gpu(x) cm = gpu(m) - y, back = forward((m, x) -> m(x), m, x) - cy, cback = forward((m, x) -> m(x), cm, cx) + y, back = pullback((m, x) -> m(x), m, x) + cy, cback = pullback((m, x) -> m(x), cm, cx) @test cpu(cy) ≈ y @@ -28,8 +28,8 @@ using Flux: forward cx = gpu(x) cm = gpu(m) - y, back = forward((m, x) -> m(x), m, x) - cy, cback = forward((m, x) -> m(x), cm, cx) + y, back = pullback((m, x) -> m(x), m, x) + cy, cback = pullback((m, x) -> m(x), cm, cx) @test cpu(cy) ≈ y diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl index e417ea58..7753837a 100644 --- a/test/cuda/curnn.jl +++ b/test/cuda/curnn.jl @@ -1,5 +1,5 @@ using Flux, CuArrays, Test -using Flux: forward +using Flux: pullback @testset for R in [RNN, GRU, LSTM] m = R(10, 5) |> gpu @@ -13,7 +13,7 @@ end @testset "RNN" begin @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5) rnn = R(10, 5) - curnn = mapleaves(gpu, rnn) + curnn = fmap(gpu, rnn) Flux.reset!(rnn) Flux.reset!(curnn) @@ -22,8 +22,8 @@ end rand(10, batch_size) cux = gpu(x) - y, back = forward((r, x) -> r(x), rnn, x) - cuy, cuback = forward((r, x) -> r(x), curnn, cux) + y, back = pullback((r, x) -> r(x), rnn, x) + cuy, cuback = pullback((r, x) -> r(x), curnn, cux) @test y ≈ collect(cuy) @test haskey(Flux.CUDA.descs, curnn.cell) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index cda0cc59..22a5d283 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -1,7 +1,7 @@ using Flux, Test, Statistics -using Zygote: forward +using Zygote: pullback -trainmode(f, x...) = forward(f, x...)[1] +trainmode(f, x...) = pullback(f, x...)[1] trainmode(f) = (x...) -> trainmode(f, x...) @testset "Dropout" begin @@ -42,6 +42,8 @@ end let m = BatchNorm(2), x = [1.0 3.0 5.0; 2.0 4.0 6.0] + @test length(params(m)) == 2 + @test m.β == [0, 0] # initβ(2) @test m.γ == [1, 1] # initγ(2) # initial m.σ is 1 @@ -109,7 +111,9 @@ end expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) # begin tests let m = InstanceNorm(2), sizes = (3, 2, 2), - x = reshape(collect(1:prod(sizes)), sizes) + x = reshape(collect(1:prod(sizes)), sizes) + + @test length(params(m)) == 2 x = Float64.(x) @test m.β == [0, 0] # initβ(2) @test m.γ == [1, 1] # initγ(2) @@ -192,7 +196,9 @@ end squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions let m = GroupNorm(4,2), sizes = (3,4,2), - x = reshape(collect(1:prod(sizes)), sizes) + x = reshape(collect(1:prod(sizes)), sizes) + + @test length(params(m)) == 2 x = Float64.(x) @test m.β == [0, 0, 0, 0] # initβ(32) @test m.γ == [1, 1, 1, 1] # initγ(32) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index b853fc19..9e01af07 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -55,7 +55,7 @@ const ϵ = 1e-7 y = rand(T, 2) ŷ = rand(T, 2) for f in (mse, crossentropy, logitcrossentropy) - fwd, back = Flux.forward(f, ŷ, y) + fwd, back = Flux.pullback(f, ŷ, y) @test fwd isa T @test eltype(back(one(T))[1]) == T end diff --git a/test/utils.jl b/test/utils.jl index 3a840261..18a57139 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -83,7 +83,6 @@ end # Self-referential array. Just want params, no stack overflow pls. r = Any[nothing,m] - Flux.children(a::Vector{Any}) = Tuple(a) r[1] = r @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)] end