diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index d92388e1..cb0c6615 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks. ```@docs Chain Dense +Conv2D ``` ## Recurrent Layers diff --git a/src/Flux.jl b/src/Flux.jl index 526d6bb8..2acdb177 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,13 +7,14 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, +export Chain, Dense, RNN, LSTM, Conv2D, Dropout, LayerNorm, BatchNorm, SGD, ADAM, Momentum, Nesterov, AMSGrad, param, params, mapleaves using NNlib -export σ, sigmoid, relu, leakyrelu, elu, swish, softmax +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, + conv2d, maxpool2d, avgpool2d include("tracker/Tracker.jl") using .Tracker @@ -27,6 +28,7 @@ include("treelike.jl") include("layers/stateless.jl") include("layers/basic.jl") +include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") diff --git a/src/layers/conv.jl b/src/layers/conv.jl new file mode 100644 index 00000000..e267510b --- /dev/null +++ b/src/layers/conv.jl @@ -0,0 +1,33 @@ +""" + Conv2D(size, in=>out) + Conv2d(size, in=>out, relu) + +Standard convolutional layer. `size` should be a tuple like `(2, 2)`. +`in` and `out` specify the number of input and output channels respectively. + +Data should be stored in HWCN order. In other words, a 100×100 RGB image would +be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. + +Takes the keyword arguments `pad` and `stride`. +""" +struct Conv2D{F,A} + σ::F + weight::A + stride::Int + pad::Int +end + +Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; + init = initn, stride = 1, pad = 0) = + Conv2D(σ, param(initn(k..., ch...)), stride, pad) + +Flux.treelike(Conv2D) + +(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad)) + +function Base.show(io::IO, l::Conv2D) + print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")") + print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4)) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end diff --git a/src/onehot.jl b/src/onehot.jl index f94fb93e..4f121958 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -18,7 +18,9 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] +Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 39810069..b4cd27c6 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -12,16 +12,17 @@ function scan(x::TrackedArray) return end -back(c::Call, Δ) = back(c.func, Δ, c.args...) -back(::Call{Void}, Δ) = nothing +back_(f, y, args...) = back(f, args...) +back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) +back_(::Call{Void}, y, Δ) = nothing function back(x::TrackedArray, Δ) ref = x.ref -= 1 if isdefined(x, :grad) x.grad .+= Δ - ref == 0 && back(x.f, x.grad) + ref == 0 && back_(x.f, x.data, x.grad) else - ref == 0 && back(x.f, Δ) + ref == 0 && back_(x.f, x.data, Δ) end return end @@ -35,6 +36,9 @@ end # Interface methods +# TODO: if an error occurs in `back` the refcounts will be broken +# and `back` will silently fail to update. + function back!(x::TrackedArray, Δ) scan(x) back(x, Δ) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index f3221bd8..2dc25e52 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys) @back(ys, Δ[size(xs,1)+1:end, i...]) end +Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = + TrackedArray(Call(reshape, xs, dims...)) + +back(::typeof(reshape), Δ, xs::TrackedArray, _...) = + back(xs, reshape(Δ, size(xs))) + # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) @@ -123,12 +129,36 @@ end # NNlib -import NNlib: softmax, ∇softmax +using NNlib +import NNlib: softmax, ∇softmax, conv2d, pool softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) +# TODO: can store kwargs efficiently in namedtuples +_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad) + +conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) +conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) +conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) = + TrackedArray(Call(_conv2d, x, w, stride, padding)) + +function back(::typeof(_conv2d), Δ, x, w, stride, pad) + @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad)) + @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad)) +end + +_pool(x, k, mode) = pool(x, window = k, mode = mode) + +pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) = + TrackedArray(Call(_pool, x, window, mode)) + +back_(::typeof(_pool), y, Δ, x, k, mode) = + back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode)) + # Broadcasting using ForwardDiff: Dual, partials diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 68211aa3..cbcd3ad8 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...) return grads end -gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6)) +gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5)) diff --git a/src/utils.jl b/src/utils.jl index afe926d9..bba3e416 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,8 +4,6 @@ initn(dims...) = randn(dims...)/100 glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims))) glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims))) -flatten(xs) = reshape(xs, size(xs, 1), :) - unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...) diff --git a/test/tracker.jl b/test/tracker.jl index ac031915..dc11420b 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -1,5 +1,6 @@ using Flux.Tracker, Base.Test, NNlib using Flux.Tracker: gradcheck +using NNlib gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...) gradtest(f, dims...) = gradtest(f, rand.(dims)...) @@ -45,4 +46,8 @@ end 2y + x end +@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2)) +@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2)) +@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2)) + end #testset