commit
6b6974e14a
@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks.
|
|||||||
```@docs
|
```@docs
|
||||||
Chain
|
Chain
|
||||||
Dense
|
Dense
|
||||||
|
Conv2D
|
||||||
```
|
```
|
||||||
|
|
||||||
## Recurrent Layers
|
## Recurrent Layers
|
||||||
|
@ -7,13 +7,14 @@ module Flux
|
|||||||
using Juno, Requires
|
using Juno, Requires
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM,
|
export Chain, Dense, RNN, LSTM, Conv2D,
|
||||||
Dropout, LayerNorm, BatchNorm,
|
Dropout, LayerNorm, BatchNorm,
|
||||||
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||||
param, params, mapleaves
|
param, params, mapleaves
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
|
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
|
||||||
|
conv2d, maxpool2d, avgpool2d
|
||||||
|
|
||||||
include("tracker/Tracker.jl")
|
include("tracker/Tracker.jl")
|
||||||
using .Tracker
|
using .Tracker
|
||||||
@ -27,6 +28,7 @@ include("treelike.jl")
|
|||||||
|
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("layers/basic.jl")
|
include("layers/basic.jl")
|
||||||
|
include("layers/conv.jl")
|
||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
|
|
||||||
|
33
src/layers/conv.jl
Normal file
33
src/layers/conv.jl
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""
|
||||||
|
Conv2D(size, in=>out)
|
||||||
|
Conv2d(size, in=>out, relu)
|
||||||
|
|
||||||
|
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
||||||
|
`in` and `out` specify the number of input and output channels respectively.
|
||||||
|
|
||||||
|
Data should be stored in HWCN order. In other words, a 100×100 RGB image would
|
||||||
|
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||||
|
|
||||||
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
|
"""
|
||||||
|
struct Conv2D{F,A}
|
||||||
|
σ::F
|
||||||
|
weight::A
|
||||||
|
stride::Int
|
||||||
|
pad::Int
|
||||||
|
end
|
||||||
|
|
||||||
|
Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
|
init = initn, stride = 1, pad = 0) =
|
||||||
|
Conv2D(σ, param(initn(k..., ch...)), stride, pad)
|
||||||
|
|
||||||
|
Flux.treelike(Conv2D)
|
||||||
|
|
||||||
|
(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad))
|
||||||
|
|
||||||
|
function Base.show(io::IO, l::Conv2D)
|
||||||
|
print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")")
|
||||||
|
print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4))
|
||||||
|
l.σ == identity || print(io, ", ", l.σ)
|
||||||
|
print(io, ")")
|
||||||
|
end
|
@ -18,7 +18,9 @@ end
|
|||||||
|
|
||||||
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
|
||||||
|
|
||||||
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
|
Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
|
||||||
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
|
||||||
|
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
|
||||||
|
|
||||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||||
|
|
||||||
|
@ -12,16 +12,17 @@ function scan(x::TrackedArray)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(c::Call, Δ) = back(c.func, Δ, c.args...)
|
back_(f, y, args...) = back(f, args...)
|
||||||
back(::Call{Void}, Δ) = nothing
|
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||||
|
back_(::Call{Void}, y, Δ) = nothing
|
||||||
|
|
||||||
function back(x::TrackedArray, Δ)
|
function back(x::TrackedArray, Δ)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
x.grad .+= Δ
|
x.grad .+= Δ
|
||||||
ref == 0 && back(x.f, x.grad)
|
ref == 0 && back_(x.f, x.data, x.grad)
|
||||||
else
|
else
|
||||||
ref == 0 && back(x.f, Δ)
|
ref == 0 && back_(x.f, x.data, Δ)
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -35,6 +36,9 @@ end
|
|||||||
|
|
||||||
# Interface methods
|
# Interface methods
|
||||||
|
|
||||||
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||||
|
# and `back` will silently fail to update.
|
||||||
|
|
||||||
function back!(x::TrackedArray, Δ)
|
function back!(x::TrackedArray, Δ)
|
||||||
scan(x)
|
scan(x)
|
||||||
back(x, Δ)
|
back(x, Δ)
|
||||||
|
@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys)
|
|||||||
@back(ys, Δ[size(xs,1)+1:end, i...])
|
@back(ys, Δ[size(xs,1)+1:end, i...])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
||||||
|
TrackedArray(Call(reshape, xs, dims...))
|
||||||
|
|
||||||
|
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||||
|
back(xs, reshape(Δ, size(xs)))
|
||||||
|
|
||||||
# Reductions
|
# Reductions
|
||||||
|
|
||||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||||
@ -123,12 +129,36 @@ end
|
|||||||
|
|
||||||
# NNlib
|
# NNlib
|
||||||
|
|
||||||
import NNlib: softmax, ∇softmax
|
using NNlib
|
||||||
|
import NNlib: softmax, ∇softmax, conv2d, pool
|
||||||
|
|
||||||
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||||
|
|
||||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||||
|
|
||||||
|
# TODO: can store kwargs efficiently in namedtuples
|
||||||
|
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
|
||||||
|
|
||||||
|
conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
|
||||||
|
TrackedArray(Call(_conv2d, x, w, stride, padding))
|
||||||
|
conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
|
||||||
|
TrackedArray(Call(_conv2d, x, w, stride, padding))
|
||||||
|
conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
|
||||||
|
TrackedArray(Call(_conv2d, x, w, stride, padding))
|
||||||
|
|
||||||
|
function back(::typeof(_conv2d), Δ, x, w, stride, pad)
|
||||||
|
@back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
|
||||||
|
@back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
|
||||||
|
end
|
||||||
|
|
||||||
|
_pool(x, k, mode) = pool(x, window = k, mode = mode)
|
||||||
|
|
||||||
|
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0) =
|
||||||
|
TrackedArray(Call(_pool, x, window, mode))
|
||||||
|
|
||||||
|
back_(::typeof(_pool), y, Δ, x, k, mode) =
|
||||||
|
back(x, NNlib.pool_grad(data(x), y, Δ, window = k, mode = mode))
|
||||||
|
|
||||||
# Broadcasting
|
# Broadcasting
|
||||||
|
|
||||||
using ForwardDiff: Dual, partials
|
using ForwardDiff: Dual, partials
|
||||||
|
@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
|
|||||||
return grads
|
return grads
|
||||||
end
|
end
|
||||||
|
|
||||||
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
|
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))
|
||||||
|
@ -4,8 +4,6 @@ initn(dims...) = randn(dims...)/100
|
|||||||
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
||||||
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
||||||
|
|
||||||
flatten(xs) = reshape(xs, size(xs, 1), :)
|
|
||||||
|
|
||||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
|
|
||||||
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux.Tracker, Base.Test, NNlib
|
||||||
using Flux.Tracker: gradcheck
|
using Flux.Tracker: gradcheck
|
||||||
|
using NNlib
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
||||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||||
@ -45,4 +46,8 @@ end
|
|||||||
2y + x
|
2y + x
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||||
|
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
|
||||||
|
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user