Merge branch 'master' into depthwiseconv
This commit is contained in:
commit
4d17a1a809
|
@ -31,8 +31,8 @@ Flux has powerful high-level features, and common architectures can be defined i
|
|||
```julia
|
||||
model = Chain(
|
||||
Dense(768, 128, σ),
|
||||
LSTM(128, 256)
|
||||
LSTM(256, 128)
|
||||
LSTM(128, 256),
|
||||
LSTM(256, 128),
|
||||
Dense(128, 10),
|
||||
softmax)
|
||||
|
||||
|
|
|
@ -129,7 +129,7 @@ linear2 = linear(3, 2)
|
|||
|
||||
model(x) = linear2(σ.(linear1(x)))
|
||||
|
||||
model(x) # => 2-element vector
|
||||
model(rand(5)) # => 2-element vector
|
||||
```
|
||||
|
||||
Another (equivalent) way is to create a struct that explicitly represents the affine layer.
|
||||
|
|
|
@ -20,7 +20,7 @@ struct Call{F,As<:Tuple}
|
|||
args::As
|
||||
end
|
||||
|
||||
Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
|
||||
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
||||
Call() = Call(nothing, ())
|
||||
|
||||
# When deserialising, the object_id changes
|
||||
|
@ -46,7 +46,14 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
|
|||
|
||||
function _forward end
|
||||
|
||||
function track(f, xs...; kw...)
|
||||
function track(f::F, xs...) where F
|
||||
y, back = _forward(f, xs...)
|
||||
ts = map(tracker, xs)
|
||||
c = Call(back, ts)
|
||||
track(c, y)
|
||||
end
|
||||
|
||||
function track_kw(f::F, xs...; kw...) where F
|
||||
y, back = _forward(f, xs...; kw...)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
|
|
@ -101,7 +101,7 @@ Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
|||
end
|
||||
end
|
||||
|
||||
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
||||
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
||||
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||
|
@ -324,9 +324,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(depthwiseconv, x, w; kw...)
|
||||
|
||||
@grad depthwiseconv(x, w; kw...) =
|
||||
depthwiseconv(data(x), data(w); kw...),
|
||||
|
@ -334,9 +334,9 @@ depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x
|
|||
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
|
@ -344,14 +344,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
|||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
|
||||
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
|
|
|
@ -152,3 +152,13 @@ function gradient(f, args...)
|
|||
end
|
||||
|
||||
derivative(f, x) = gradient(f, x)[1]
|
||||
|
||||
# Non-nesting versions
|
||||
|
||||
function gradient_(f, xs...)
|
||||
xs = param.(xs)
|
||||
l = f(xs...)
|
||||
losscheck(l)
|
||||
back!(l)
|
||||
grad.(xs)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue