Merge branch 'master' into depthwiseconv

This commit is contained in:
Avik Pal 2018-08-03 19:41:50 +05:30 committed by GitHub
commit 4d17a1a809
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 14 deletions

View File

@@ -31,8 +31,8 @@ Flux has powerful high-level features, and common architectures can be defined i
```julia
model = Chain(
Dense(768, 128, σ),
LSTM(128, 256)
LSTM(256, 128)
LSTM(128, 256),
LSTM(256, 128),
Dense(128, 10),
softmax)

View File

@@ -129,7 +129,7 @@ linear2 = linear(3, 2)
model(x) = linear2(σ.(linear1(x)))
model(x) # => 2-element vector
model(rand(5)) # => 2-element vector
```
Another (equivalent) way is to create a struct that explicitly represents the affine layer.

View File

@@ -20,7 +20,7 @@ struct Call{F,As<:Tuple}
args::As
end
Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
# Fully-parametric constructor: taking `f::F` and `args::T` keeps the
# resulting `Call{F,T}` concretely typed for dispatch and specialization.
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
# Empty call record: no function (`nothing`) and no arguments.
Call() = Call(nothing, ())
# When deserialising, the object_id changes
@@ -46,7 +46,14 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end
function track(f, xs...; kw...)
# Record a tracked primitive application: run the custom forward pass for
# `f`, then wrap its output in a tracked value whose `Call` holds the
# pullback `back` together with the trackers of every input.
function track(f::F, xs...) where F
    y, back = _forward(f, xs...)
    input_trackers = map(tracker, xs)
    return track(Call(back, input_trackers), y)
end
# Keyword-accepting variant of `track`: identical bookkeeping, but the
# keyword arguments are forwarded through to `_forward` (keywords do not
# participate in dispatch, so they are threaded explicitly).
function track_kw(f::F, xs...; kw...) where F
    y, back = _forward(f, xs...; kw...)
    recorded = Call(back, map(tracker, xs))
    return track(recorded, y)
end

View File

@@ -101,7 +101,7 @@ Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
end
end
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
# Tracked overload of `repeat`: routes through `track_kw` so keyword
# arguments (e.g. `inner`/`outer`) reach the forward pass and its gradient.
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
# NOTE(review): the argument is `xs`, but the keyword defaults reference
# `ndims(A)` — `A` is not in scope here, so these defaults likely should be
# `ndims(xs)`. Confirm against the full definition (body truncated in this view).
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
@@ -324,9 +324,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
# Tracked overloads of `depthwiseconv`: whenever the input `x` or the
# filter `w` is a `TrackedArray`, dispatch to `track_kw` so conv keywords
# (stride, padding, …) are forwarded to the forward pass and gradient.
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(depthwiseconv, x, w; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(depthwiseconv, x, w; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(depthwiseconv, x, w; kw...)
@grad depthwiseconv(x, w; kw...) =
depthwiseconv(data(x), data(w); kw...),
@@ -334,9 +334,9 @@ depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
# Tracked overloads of `conv`: same pattern as `depthwiseconv` — any
# tracked operand routes through `track_kw` so keyword arguments reach
# the `@grad` rule below.
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
@grad conv(x, w; kw...) =
conv(data(x), data(w); kw...),
@@ -344,14 +344,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
# Tracked `maxpool`: `track_kw` forwards pooling keywords to the gradient rule.
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
@grad function maxpool(x, k; kw...)
y = maxpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
# Tracked `meanpool`: `track_kw` forwards pooling keywords to the gradient rule.
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
@grad function meanpool(x, k; kw...)
y = meanpool(data(x), k; kw...)

View File

@@ -152,3 +152,13 @@ function gradient(f, args...)
end
derivative(f, x) = gradient(f, x)[1]
# Non-nesting versions
# Non-nesting gradient helper: wrap each argument as a tracked `param`,
# evaluate `f`, validate the scalar loss with `losscheck`, backpropagate
# with `back!`, and return the gradient of every wrapped input.
function gradient_(f, xs...)
    params = map(param, xs)
    loss = f(params...)
    losscheck(loss)
    back!(loss)
    return map(grad, params)
end