Merge branch 'master' into HEAD

This commit is contained in:
Mike J Innes 2018-02-13 14:45:37 +00:00
commit 2f29733888
26 changed files with 910 additions and 233 deletions

View File

@ -3,8 +3,13 @@ DataFlow 0.2.1
Juno
MacroTools 0.3.3
NNlib
ForwardDiff 0.5.0
Requires
Adapt
GZip
Colors
# AD
ForwardDiff 0.5.0
DiffRules
SpecialFunctions
NaNMath

View File

@ -10,6 +10,7 @@ makedocs(modules=[Flux, NNlib],
"Building Models" =>
["Basics" => "models/basics.md",
"Recurrence" => "models/recurrence.md",
"Regularisation" => "models/regularisation.md",
"Model Reference" => "models/layers.md"],
"Training Models" =>
["Optimisers" => "training/optimisers.md",

View File

@ -0,0 +1,47 @@
# Regularisation
Applying regularisation to model parameters is straightforward. We just need to
apply an appropriate regulariser, such as `norm`, to each model parameter and
add the result to the overall loss.
For example, say we have a simple regression.
```julia
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
```
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.
```julia
penalty() = norm(m.W) + norm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
```
When working with layers, Flux provides the `params` function to grab all
parameters at once. We can easily penalise everything with `sum(norm, params)`.
```julia
julia> params(m)
2-element Array{Any,1}:
param([0.355408 0.533092; … 0.430459 0.171498])
param([0.0, 0.0, 0.0, 0.0, 0.0])
julia> sum(norm, params(m))
26.01749952921026 (tracked)
```
Here's a larger example with a multi-layer perceptron.
```julia
m = Chain(
Dense(28^2, 128, relu),
Dense(128, 32, relu),
Dense(32, 10), softmax)
ps = params(m)
loss(x, y) = crossentropy(m(x), y) + sum(norm, ps)
loss(rand(28^2), rand(10))
```

View File

@ -19,7 +19,7 @@ export σ, sigmoid, logσ, logsigmoid, relu, leakyrelu, elu, swish, softmax, log
include("tracker/Tracker.jl")
using .Tracker
export Tracker
import .Tracker: data, value
import .Tracker: data
include("optimise/Optimise.jl")
using .Optimise
@ -34,6 +34,10 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")
include("jit/JIT.jl")
include("data/Data.jl")
@require CuArrays include("cuda/cuda.jl")
end # module

7
src/cuda/cuda.jl Normal file
View File

@ -0,0 +1,7 @@
module CUDA
using CuArrays
CuArrays.cudnn_available() && include("cudnn.jl")
end

368
src/cuda/cudnn.jl Normal file
View File

@ -0,0 +1,368 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
cudnnDataType, TensorDesc, FilterDesc
mutable struct DropoutDesc
ptr::Ptr{Void}
states::CuVector{UInt8}
end
Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
function DropoutDesc(ρ::Real; seed::Integer=0)
d = [C_NULL]
s = Csize_t[0]
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
desc,libcudnn_handle[],ρ,states,length(states),seed)
finalizer(desc, x ->
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
return desc
end
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
const LINEAR_INPUT = 0
const SKIP_INPUT = 1
const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1
const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2
# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
function params(w::CuVector, input, hidden, n = 1)
slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
wx = slice(0, (input, hidden*n))
wh = slice(length(wx), (hidden, hidden*n))
bias = w[length(wx)+length(wh) + (1:hidden*n)]
(wx, wh), bias
end
mutable struct RNNDesc{T}
mode::Int
input::Int
hidden::Int
params::CuVector{T}
weights::NTuple{2,CuMatrix{T}}
bias::CuVector{T}
ptr::Ptr{Void}
end
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
function rnnParamSize(T, r, input)
size = Csize_t[0]
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
return Int(size[])÷sizeof(T)
end
ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
d = [C_NULL]
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
dropoutDesc = DropoutDesc(0)
inputMode = LINEAR_INPUT
direction = UNIDIRECTIONAL
algo = RNN_ALGO_STANDARD
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
w = cuzeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd, x ->
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
return rd
end
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
libcudnn_handle[], r, seqlen, xdesc, size)
return Int(size[])
end
const workspace = [CuVector{UInt8}(1)]
getworkspace(bytes) =
length(workspace[]) bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(bytes))
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
size = Csize_t[0]
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
libcudnn_handle[], r, seqlen, xdesc, size)
return Int(size[])
end
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, reserve=nothing) where T
if reserve == nothing
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Void}, Ptr{Void}, Cint,
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T},
Ptr{Void}, Csize_t),
libcudnn_handle[], rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace))
else
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
(Ptr{Void}, Ptr{Void}, Cint,
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
libcudnn_handle[], rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve))
end
end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Void) = C_NULL, C_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
h = hBatch(x, h_)
c = c_ == nothing ? nothing : hBatch(x, c_)
@assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2)
seqLength = 1
xdesc = xDesc(x)
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
ho = similar(h)
ydesc = xDesc(y)
workspace = getworkspace(rnn, seqLength, xdesc)
reserve = train == Val{true} ?
CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
xdesc, x,
hDesc(h)...,
hDesc(c)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
ydesc, y,
hDesc(ho)...,
hDesc(co)...,
workspace, reserve)
result = c == nothing ? (y, ho) : (y, ho, co)
return train == Val{true} ? (reserve, result) : result
end
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
forward(rnn, x, h, c, Val{true})
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
(Ptr{Void}, Ptr{Void}, Cint,
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
# Same as above, any more efficient way?
dy = dy_ isa Integer ? zeros(y) : dy_
yd = xDesc(y)
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
dh = similar(h)
dc = c == nothing ? nothing : similar(c)
cudnnRNNBackwardData(rnn, 1,
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
workspace[], reserve)
return c == nothing ? (dx, dh) : (dx, dh, dc)
end
backwardData(rnn, y, dy, dho, hx, reserve) =
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
workspace, reserve) where T
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
(Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength
Ptr{Ptr{Void}}, Ptr{T}, #x
Ptr{Void}, Ptr{T}, #hx
Ptr{Ptr{Void}}, Ptr{T}, #y
Ptr{Void}, Csize_t, #ws
Ptr{Void}, Ptr{T}, #dw
Ptr{Void}, Csize_t), #rs
libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve))
end
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
dw = zeros(rnn.params)
cudnnRNNBackwardWeights(rnn, 1,
xDesc(x), x, hDesc(h)..., xDesc(y), y,
FilterDesc(T, (1, 1, length(dw))), dw,
workspace[], reserve)
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end
# Interface
import ..Flux: Flux, relu
import ..Flux.Tracker: TrackedArray
using CUDAnative
using CuArrays: @cuindex, cudims
function copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src)
return dst
end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, Flux.data(m.Wi))
copy_transpose!(Wh, Flux.data(m.Wh))
copy_transpose!(d.bias, Flux.data(m.b))
return
end
function RNNDesc(m::CuRNNs{T}) where T
h, i = length(m.h), size(m.Wi, 2)
mode = m isa CuRNN ?
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
m isa CuGRU ? GRU : LSTM
r = RNNDesc{T}(mode, i, h)
return r
end
const descs = WeakKeyDict()
function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
copyparams!(rnn, d)
return d
end
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
mutable struct RNNCall{R}
rnn::R
reserve::CuVector{UInt8}
RNNCall{R}(rnn::R) where R = new(rnn)
end
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
function (c::RNNCall)(args...)
rs, result = forwardTrain(desc(c.rnn), args...)
c.reserve = rs
return result
end
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h[1], h[2]) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end
function accum_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] += src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src)
return dst
end
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
y, ho = y_
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
@back(x, dx)
@back(h, unbroadcast(h, dh))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
# We don't have to make this assumption, it's just slightly more complex.
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
y, ho, co = y_
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
@back(x, dx)
@back(h, unbroadcast(h, dh))
@back(c, unbroadcast(h, dc))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end

8
src/jit/JIT.jl Normal file
View File

@ -0,0 +1,8 @@
module JIT
include("shapes.jl")
include("inplace.jl")
include("trace.jl")
include("lib.jl")
end

11
src/jit/inplace.jl Normal file
View File

@ -0,0 +1,11 @@
mutable struct Cached{F,A}
f::F
buffer::A
end
function (c::Cached)(args...)
sh = shape(c.f, shape(args)...)
bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
y = restructure(sh, c.buffer)
inplace!(c.f, y, args...)
end

9
src/jit/lib.jl Normal file
View File

@ -0,0 +1,9 @@
# Primitive definitions
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
A_mul_B!(C, A, B)
shape(::typeof(broadcast), f, xs...) =
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)

37
src/jit/shapes.jl Normal file
View File

@ -0,0 +1,37 @@
struct Shape{T,N}
dims::NTuple{N,Int}
end
VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
Base.size(s::Shape) = s.dims
Base.size(s::Shape, n) = s.dims[n]
Base.length(s::Shape) = prod(s.dims)
Base.eltype(s::Shape{T}) where T = T
Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))
function Base.show(io::IO, s::Shape{T}) where T
print(io, "Shape{$T}(")
join(io, s.dims, ", ")
print(io, ")")
end
shape(x) = typeof(x)
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
bytes(s::Shape) = sizeof(s)
bytes(x::Tuple) = sum(bytes.(x))
# Recover structure from byte buffers
# Make sure to hold on to the parent buffer for the lifetime of the data.
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
reshape(reinterpret(T, buf), size(sh))
end

25
src/jit/trace.jl Normal file
View File

@ -0,0 +1,25 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow
using DataFlow: inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
function graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(inputs, x)
cache[x] =
i > 0 ? inputnode(i) :
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
constant(x)
end
function trace(f, args...)
inputs = param.(args)
graph(f(inputs...), inputs...)
end

View File

@ -113,15 +113,15 @@ function (BN::BatchNorm)(x)
else
T = eltype(x)
ϵ = T(BN.ϵ)
ϵ = data(convert(T, BN.ϵ))
m = size(x, 2) # batch size
μ = mean(x, 2)
σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
# update moving mean/std
mtm = T(BN.momentum)
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ)
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
end
λ.(γ .* ((x .- μ) ./ σ) .+ β)

View File

@ -1,7 +1,6 @@
# TODO: broadcasting cat
combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
gate(h, n) = (1:h) + h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
# Stateful recurrence
@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs)))
# Vanilla RNN
struct RNNCell{D,V}
d::D
mutable struct RNNCell{F,A,V}
σ::F
Wi::A
Wh::A
b::V
h::V
end
RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(zeros(out)), param(initn(out)))
function (m::RNNCell)(h, x)
h = m.d(combine(x, h))
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
h = σ.(Wi*x .+ Wh*h .+ b)
return h, h
end
@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h
treelike(RNNCell)
function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
"""
@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
# LSTM
struct LSTMCell{D1,D2,V}
forget::D1
input::D1
output::D1
cell::D2
h::V; c::V
mutable struct LSTMCell{A,V}
Wi::A
Wh::A
b::V
h::V
c::V
end
function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
Dense(in+out, out, tanh, initW = initW, initb = initb),
param(initW(out)), param(initW(out)))
cell.forget.b.data .= 1
function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(initn(out)), param(initn(out)))
cell.b.data[gate(out, 2)] = 1
return cell
end
function (m::LSTMCell)(h_, x)
h, c = h_
x = combine(x, h)
forget, input, output, cell =
m.forget(x), m.input(x), m.output(x), m.cell(x)
h, c = h_ # TODO: nicer syntax on 0.7
b, o = m.b, size(h, 1)
g = m.Wi*x .+ m.Wh*h .+ b
input = σ.(gate(g, o, 1))
forget = σ.(gate(g, o, 2))
cell = tanh.(gate(g, o, 3))
output = σ.(gate(g, o, 4))
c = forget .* c .+ input .* cell
h = output .* tanh.(c)
return (h, c), h
h = output .* tanh.(c)
return (h, c), h
end
hidden(m::LSTMCell) = (m.h, m.c)
treelike(LSTMCell)
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(",
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
"""
LSTM(in::Integer, out::Integer, σ = tanh)
@ -153,38 +161,33 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
# GRU
struct GRUCell{D1,D2,V}
update::D1
reset::D1
candidate::D2
mutable struct GRUCell{A,V}
Wi::A
Wh::A
b::V
h::V
end
function GRUCell(in, out)
cell = GRUCell(Dense(in+out, out, σ),
Dense(in+out, out, σ),
Dense(in+out, out, tanh),
param(initn(out)))
return cell
end
GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zeros(out*3)), param(initn(out)))
function (m::GRUCell)(h, x)
x = combine(x, h)
z = m.update(x)
r = m.reset(x)
= m.candidate(combine(r.*h, x))
h = (1.-z).*h .+ z.*
return h, h
b, o = m.b, size(h, 1)
gx, gh = m.Wi*x, m.Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
= tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h = (1.-z).* .+ z.*h
return h, h
end
hidden(m::GRUCell) = m.h
treelike(GRUCell)
Base.show(io::IO, m::GRUCell) =
print(io, "GRUCell(",
size(m.update.W, 2) - size(m.update.W, 1), ", ",
size(m.update.W, 1), ')')
Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")
"""
GRU(in::Integer, out::Integer, σ = tanh)

View File

@ -1,5 +1,5 @@
using Juno
using Flux.Tracker: back!, value
using Flux.Tracker: back!
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(value(l)) && error("Loss is Inf")
isnan(value(l)) && error("Loss is NaN")
isinf(l) && error("Loss is Inf")
isnan(l) && error("Loss is NaN")
back!(l)
opt()
cb() == :stop && break

View File

@ -2,8 +2,12 @@ module Tracker
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
data(x) = x
istracked(x) = false
tracker(x) = nothing
istracked(x) = tracker(x) nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x))
struct Call{F,As<:Tuple}
func::F
@ -14,109 +18,39 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
(c::Call)() = c.func(data.(c.args)...)
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
mutable struct Tracked{T}
ref::UInt32
f::Call
data::A
grad::A
TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
isleaf::Bool
data::T
grad::T
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
end
TrackedScalar{T,A} = TrackedArray{T,0,A}
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x)
track(f::Call, x) = Tracked(f, x)
track(f::Call) = track(f, f())
track(f, xs...) = track(Call(f, xs...))
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
isleaf(x::TrackedArray) = x.f == Call(nothing)
param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
param(xs::Real) = param(fill(xs))
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data
grad(x::TrackedArray) = x.grad
# Fallthrough methods
for f in :[Base.size, Base.ndims].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...)
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
# TODO decide if keeping both data and value. The problem is TrackedScalar
value(x) = x
value(x::TrackedArray) = data(x)
value(x::TrackedScalar) = data(x)[]
Base.:(==)(x::TrackedArray, y) = value(x) == y
Base.:(==)(y, x::TrackedArray) = y == value(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr
print(io, "param(")
Base.showarray(io, data(X), true)
print(io, ")")
else
header && print(io, "Tracked ")
Base.showarray(io, data(X), false, header = header)
end
end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call(nothing)
data(x::Tracked) = x.data
grad(x::Tracked) = x.grad
include("back.jl")
include("lib.jl")
include("scalar.jl")
include("array.jl")
include("numeric.jl")
using DataFlow
using DataFlow: inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(inputs, x)
cache[x] =
i > 0 ? inputnode(i) :
isleaf(x) ? constant(x) :
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
end
_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
function graph(f, args...)
inputs = param.(args)
_graph(f(inputs...), inputs...)
end
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
import Adapt.adapt
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end

View File

@ -1,45 +1,100 @@
toarray(xs::AbstractArray, ys::AbstractArray) = ys
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A}
data::A
grad::A
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end
unarray(xs) = xs
unarray(xs::AbstractArray{T,0} where T) = xs[]
tracker(x::TrackedArray) = x.tracker
Base.getindex(xs::TrackedArray, i...) =
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr
print(io, "param(")
Base.showarray(io, data(X), true)
print(io, ")")
else
header && print(io, "Tracked ")
Base.showarray(io, data(X), false, header = header)
end
end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Use back!(x, Δ)")
# Fallthrough methods
for f in :[Base.size, Base.ndims].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...)
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
Base.:(==)(x::TrackedArray, y) = data(x) == y
Base.:(==)(y, x::TrackedArray) = y == data(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
# Array Stdlib
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs.data)
Δ′[i...] = unarray(Δ)
Δ′[i...] = Δ
@back(xs, Δ′)
end
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
Base.:-(xs::TrackedArray) = track(-, xs)
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVector, b::TrackedVector...) = TrackedArray(Call(vcat, a, b...))
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = TrackedArray(Call(vcat, a, b...))
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = TrackedArray(Call(vcat, a, b...))
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
function back(::typeof(vcat), Δ, xs...)
i = Base.tail(map(_ -> :, size(Δ)))
@ -51,32 +106,32 @@ function back(::typeof(vcat), Δ, xs...)
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
TrackedArray(Call(reshape, xs, dims...))
track(reshape, xs, dims...)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
# Reductions
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
Base.sum(xs::TrackedScalar, dim...) = xs
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
function back(::typeof(dot), Δ, xs, ys)
@back(xs, Δ.*ys)
@back(ys, Δ.*xs)
@back(xs, Δ.*data(ys))
@back(ys, Δ.*data(xs))
end
# Hacks to get std working
@ -85,29 +140,34 @@ Base.std(x::TrackedArray; mean = Base.mean(x)) =
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
Base.norm(x::TrackedArray, p::Real = 2) =
p == 1 ? sum(abs.(x)) :
p == 2 ? sqrt(sum(abs2.(x))) :
error("$p-norm not supported")
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
# BLAS
Base.diagm(x::TrackedVector) = TrackedArray(Call(diagm, x))
Base.diagm(x::TrackedVector) = track(diagm, x)
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))
for f in :[*, Ac_mul_B, A_mul_Bc].args
@eval begin
import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
$f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
$f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
end
end
@ -141,11 +201,11 @@ end
using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, logσ, ∇logσ, conv2d, pool
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
softmax(xs::TrackedArray) = track(softmax, xs)
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
@ -157,11 +217,11 @@ back(::typeof(logσ), Δ, xs) = @back(xs, ∇logσ(Δ, data(xs)))
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
TrackedArray(Call(_conv2d, x, w, stride, padding))
track(_conv2d, x, w, stride, padding)
conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
TrackedArray(Call(_conv2d, x, w, stride, padding))
track(_conv2d, x, w, stride, padding)
conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
TrackedArray(Call(_conv2d, x, w, stride, padding))
track(_conv2d, x, w, stride, padding)
function back(::typeof(_conv2d), Δ, x, w, stride, pad)
@back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
@ -171,7 +231,7 @@ end
_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
TrackedArray(Call(_pool, x, window, padding, mode))
track(_pool, x, window, padding, mode)
back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
@ -189,23 +249,24 @@ end
dualify(xs, n) = xs
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
dualify(xs::TrackedReal, ps) = Dual(data(xs), ps)
function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
# TrackedArray(Call(Broadcasted(f, broadcast(f, dargs...)), args...))
# Works around a 0.6 type inference issue
b = Broadcasted(f, out)
TrackedArray(Call(b, args...), b())
track(Call(b, args...), b())
end
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
unbroadcast(x, Δ) =
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
unbroadcast(x::Number, Δ) = sum(Δ)
function getpartial(Δ, x, i)
@inbounds p = getindex(partials(x), i)
return Δ * p

View File

@ -1,25 +1,38 @@
scan(x) = nothing
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
function scan(x::TrackedArray)
function scan(x::Tracked)
x.isleaf && return
ref = x.ref += 1
if ref == 1
scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
else
isdefined(x, :grad) || (x.grad = zeros(x.data))
isdefined(x, :grad) || (x.grad = init_grad(x.data))
end
return
end
function scan(x)
istracked(x) && scan(tracker(x))
return
end
back_(f, y, args...) = back(f, args...)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
back_(::Call{Void}, y, Δ) = nothing
function back(x::TrackedArray, Δ)
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ)
x.isleaf && (accum!(x.grad, Δ); return)
ref = x.ref -= 1
if isdefined(x, :grad)
x.grad .+= Δ
x.grad = accum!(x.grad, Δ)
ref == 0 && back_(x.f, x.data, x.grad)
else
ref == 0 && back_(x.f, x.data, Δ)
@ -27,6 +40,9 @@ function back(x::TrackedArray, Δ)
return
end
back(x, Δ) = back(tracker(x), Δ)
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
macro back(x, Δ)
quote
x = $(esc(x))
@ -39,9 +55,9 @@ end
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
function back!(x::TrackedArray, Δ)
function back!(x::Tracked, Δ)
scan(x)
back(x, Δ)
end
back!(x::TrackedScalar) = back!(x, 1)
back!(x, Δ) = back!(tracker(x), Δ)

86
src/tracker/scalar.jl Normal file
View File

@ -0,0 +1,86 @@
struct TrackedReal{T<:Real} <: Real
tracker::Tracked{T}
end
TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x)))
tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
back!(x::TrackedReal) = back!(x, 1)
function Base.show(io::IO, x::TrackedReal)
show(io, data(x))
print(io, " (tracked)")
end
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end
Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n)
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
TrackedReal{promote_type(S,T)}
using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
@eval begin
$M.$f(a::TrackedReal) = track($M.$f, a)
back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
end
end
for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
@eval begin
$M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b)
$M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b)
$M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b)
function back(::typeof($M.$f), Δ::Real, a::Real, b::Real)
@back(a, Δ * $da)
@back(b, Δ * $db)
end
end
end
# Tuples
struct TrackedTuple{T<:Tuple}
tracker::Tracked{T}
end
tracker(xs::TrackedTuple) = xs.tracker
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
print(io, " (tracked)")
end
Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) =
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))

View File

@ -35,7 +35,10 @@ end
function params(m)
ps = []
prefor(p -> p isa TrackedArray && push!(ps, p), m)
prefor(p ->
Tracker.istracked(p) && Tracker.isleaf(p) &&
!(p in ps) && push!(ps, p),
m)
return ps
end

View File

@ -72,17 +72,6 @@ end
# Other
function accuracy(m, data)
n = 0
correct = 0
for (x, y) in data
x, y = tobatch.((x, y))
n += size(x, 1)
correct += sum(argmax(m(x)) .== argmax(y))
end
return correct/n
end
"""
Returns a function that when invoked, will only be triggered at most once
during `timeout` seconds. Normally, the throttled function will run

View File

@ -21,3 +21,5 @@ cm = cu(m)
@test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
end
CuArrays.cudnn_available() && include("cudnn.jl")

31
test/cuda/cudnn.jl Normal file
View File

@ -0,0 +1,31 @@
using Flux, CuArrays, Base.Test
info("Testing Flux/CUDNN")
@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM]
x = param(rand(10,5))
cux = cu(x)
rnn = R(10, 5)
curnn = mapleaves(cu, rnn)
y = (rnn(x); rnn(x))
cuy = (curnn(cux); curnn(cux))
@test y.data collect(cuy.data)
@test haskey(Flux.CUDA.descs, curnn.cell)
Δ = randn(size(y))
Flux.back!(y, Δ)
Flux.back!(cuy, cu(Δ))
@test x.grad collect(cux.grad)
@test rnn.cell.Wi.grad collect(curnn.cell.Wi.grad)
@test rnn.cell.Wh.grad collect(curnn.cell.Wh.grad)
@test rnn.cell.b.grad collect(curnn.cell.b.grad)
@test rnn.cell.h.grad collect(curnn.cell.h.grad)
if isdefined(rnn.cell, :c)
@test rnn.cell.c.grad collect(curnn.cell.c.grad)
end
end
end

View File

@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w])

View File

@ -1,5 +1,7 @@
using Flux, Base.Test
srand(0)
@testset "Flux" begin
include("utils.jl")
@ -10,7 +12,7 @@ include("optimise.jl")
include("data.jl")
if Base.find_in_path("CuArrays") nothing
include("cuarrays.jl")
include("cuda/cuda.jl")
end
end

View File

@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: gradcheck
using Flux.Tracker: TrackedReal, gradcheck
using NNlib
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -15,7 +15,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@ -47,6 +47,7 @@ end
@test gradtest(x -> std(x, 1), rand(5,5))
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))
@test gradtest(rand(5)) do x
y = x.^2
@ -59,4 +60,24 @@ end
@test (param([1,2,3]) .< 2) == [true, false, false]
@testset "Intermediates" begin
x = param([1])
l = sum((x .+ x).^2)
Flux.back!(l)
@test x.grad == [8]
x.grad .= 0
Flux.back!(l)
@test x.grad == [8]
end
@testset "Fallbacks" begin
xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64}
# Remove this test if we do LowerTriangular properly
L = LowerTriangular(xs)
@test L*L' isa Matrix{TrackedReal{Float64}}
end
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
end #testset

View File

@ -79,3 +79,10 @@ end
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
end
end
@testset "Params" begin
m = Dense(10, 5)
@test size.(params(m)) == [(5, 10), (5,)]
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end