Merge branch 'master' into HEAD
This commit is contained in: commit 2f29733888

REQUIRE
@@ -3,8 +3,13 @@ DataFlow 0.2.1
 Juno
 MacroTools 0.3.3
 NNlib
-ForwardDiff 0.5.0
 Requires
 Adapt
 GZip
 Colors
+
+# AD
+ForwardDiff 0.5.0
+DiffRules
+SpecialFunctions
+NaNMath
@@ -10,6 +10,7 @@ makedocs(modules=[Flux, NNlib],
 "Building Models" =>
   ["Basics" => "models/basics.md",
    "Recurrence" => "models/recurrence.md",
+   "Regularisation" => "models/regularisation.md",
    "Model Reference" => "models/layers.md"],
 "Training Models" =>
   ["Optimisers" => "training/optimisers.md",
@@ -0,0 +1,47 @@
# Regularisation

Applying regularisation to model parameters is straightforward. We just need to
apply an appropriate regulariser, such as `norm`, to each model parameter and
add the result to the overall loss.

For example, say we have a simple regression.

```julia
m = Dense(10, 5)
loss(x, y) = crossentropy(softmax(m(x)), y)
```

We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.

```julia
penalty() = norm(m.W) + norm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
```

When working with layers, Flux provides the `params` function to grab all
parameters at once. We can easily penalise everything with `sum(norm, params)`.

```julia
julia> params(m)
2-element Array{Any,1}:
param([0.355408 0.533092; … 0.430459 0.171498])
param([0.0, 0.0, 0.0, 0.0, 0.0])

julia> sum(norm, params(m))
26.01749952921026 (tracked)
```

Here's a larger example with a multi-layer perceptron.

```julia
m = Chain(
  Dense(28^2, 128, relu),
  Dense(128, 32, relu),
  Dense(32, 10), softmax)

ps = params(m)

loss(x, y) = crossentropy(m(x), y) + sum(norm, ps)

loss(rand(28^2), rand(10))
```
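The penalised loss drops straight into the training loop; a minimal sketch, assuming the `SGD` optimiser and `train!` API shown elsewhere in this commit (the data below is a placeholder):

```julia
using Flux

m = Dense(10, 5)
ps = params(m)
loss(x, y) = crossentropy(softmax(m(x)), y) + sum(norm, ps)

opt  = SGD(ps, 0.1)           # optimiser over the same parameters
data = [(rand(10), rand(5))]  # placeholder batch
Flux.train!(loss, data, opt)
```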
@@ -19,7 +19,7 @@ export σ, sigmoid, logσ, logsigmoid, relu, leakyrelu, elu, swish, softmax, log
 include("tracker/Tracker.jl")
 using .Tracker
 export Tracker
-import .Tracker: data, value
+import .Tracker: data

 include("optimise/Optimise.jl")
 using .Optimise
@@ -34,6 +34,10 @@ include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalisation.jl")

+include("jit/JIT.jl")
+
 include("data/Data.jl")

+@require CuArrays include("cuda/cuda.jl")
+
 end # module
@@ -0,0 +1,7 @@
module CUDA

using CuArrays

CuArrays.cudnn_available() && include("cudnn.jl")

end
@@ -0,0 +1,368 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
  cudnnDataType, TensorDesc, FilterDesc

mutable struct DropoutDesc
  ptr::Ptr{Void}
  states::CuVector{UInt8}
end

Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr

function DropoutDesc(ρ::Real; seed::Integer=0)
  d = [C_NULL]
  s = Csize_t[0]
  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
  states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
  desc = DropoutDesc(d[], states)
  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
    desc,libcudnn_handle[],ρ,states,length(states),seed)
  finalizer(desc, x ->
    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
  return desc
end

const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2     # LSTM with no peephole connections
const GRU = 3      # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)

const LINEAR_INPUT = 0
const SKIP_INPUT = 1

const UNIDIRECTIONAL = 0
const BIDIRECTIONAL = 1

const RNN_ALGO_STANDARD = 0
const RNN_ALGO_PERSIST_STATIC = 1
const RNN_ALGO_PERSIST_DYNAMIC = 2

# param layout:
#   RNN:  [weight, bias] × [input, hidden]
#   GRU:  [weight, bias] × [input, hidden] × [reset, update, newmem]
#   LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

function params(w::CuVector, input, hidden, n = 1)
  slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
  wx = slice(0, (input, hidden*n))
  wh = slice(length(wx), (hidden, hidden*n))
  bias = w[length(wx)+length(wh) + (1:hidden*n)]
  (wx, wh), bias
end
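To make the layout concrete, here is a CPU sketch of the slicing `params` performs — an editorial illustration with made-up sizes, not part of the diff:

```julia
# For input = 10, hidden = 5 and n = 1 gates, w has 10·5 + 5·5 + 5 = 80 entries.
input, hidden, n = 10, 5, 1
w = collect(1.0:80)
slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
wx   = slice(0, (input, hidden*n))             # 10×5 input weights
wh   = slice(length(wx), (hidden, hidden*n))   # 5×5 hidden weights
bias = w[length(wx)+length(wh) + (1:hidden*n)] # the final 5 entries
```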
mutable struct RNNDesc{T}
  mode::Int
  input::Int
  hidden::Int
  params::CuVector{T}
  weights::NTuple{2,CuMatrix{T}}
  bias::CuVector{T}
  ptr::Ptr{Void}
end

Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr

function rnnParamSize(T, r, input)
  size = Csize_t[0]
  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
    libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
  return Int(size[])÷sizeof(T)
end

ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)

function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
  d = [C_NULL]
  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)

  dropoutDesc = DropoutDesc(0)
  inputMode = LINEAR_INPUT
  direction = UNIDIRECTIONAL
  algo = RNN_ALGO_STANDARD
  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
    libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))

  w = cuzeros(T, rnnParamSize(T, d[], input))
  # TODO: avoid reserve allocation here
  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
  finalizer(rd, x ->
    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
  return rd
end

function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
  size = Csize_t[0]
  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
    libcudnn_handle[], r, seqlen, xdesc, size)
  return Int(size[])
end

const workspace = [CuVector{UInt8}(1)]

getworkspace(bytes) =
  length(workspace[]) ≥ bytes ?
    workspace[] :
    (workspace[] = CuVector{UInt8}(bytes))

getworkspace(r::RNNDesc, seqlen, xdesc) =
  getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))

function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
  size = Csize_t[0]
  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
    libcudnn_handle[], r, seqlen, xdesc, size)
  return Int(size[])
end

function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                         workspace, reserve=nothing) where T
  if reserve == nothing
    @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
      (Ptr{Void}, Ptr{Void}, Cint,
       Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
       Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
       Ptr{Void}, Ptr{T},
       Ptr{Void}, Csize_t),
      libcudnn_handle[], rnn, seqlen,
      xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
      workspace, length(workspace))
  else
    @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
      (Ptr{Void}, Ptr{Void}, Cint,
       Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
       Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
      libcudnn_handle[], rnn, seqlen,
      xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
      workspace, length(workspace), reserve, length(reserve))
  end
end

xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

hDesc(h::Void) = C_NULL, C_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
  TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end

# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)

function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
  h = hBatch(x, h_)
  c = c_ == nothing ? nothing : hBatch(x, c_)
  @assert size(x, 1) == rnn.input
  @assert size(h, 1) == rnn.hidden
  @assert size(x, 2) == size(h, 2)
  seqLength = 1
  xdesc = xDesc(x)
  y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
  ho = similar(h)
  ydesc = xDesc(y)
  workspace = getworkspace(rnn, seqLength, xdesc)
  reserve = train == Val{true} ?
    CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
    nothing
  co = c == nothing ? c : similar(c)
  cudnnRNNForward(rnn, seqLength,
    xdesc, x,
    hDesc(h)...,
    hDesc(c)...,
    FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
    ydesc, y,
    hDesc(ho)...,
    hDesc(co)...,
    workspace, reserve)
  result = c == nothing ? (y, ho) : (y, ho, co)
  return train == Val{true} ? (reserve, result) : result
end

forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
  forward(rnn, x, h, c, Val{true})

function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                              wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
  @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
    (Ptr{Void}, Ptr{Void}, Cint,
     Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
     Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
     Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
     Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
    libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
    wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end

function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
  # Same as above, any more efficient way?
  dy = dy_ isa Integer ? zeros(y) : dy_
  yd = xDesc(y)
  dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
  dh = similar(h)
  dc = c == nothing ? nothing : similar(c)
  cudnnRNNBackwardData(rnn, 1,
    yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
    FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
    hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
    workspace[], reserve)
  return c == nothing ? (dx, dh) : (dx, dh, dc)
end

backwardData(rnn, y, dy, dho, hx, reserve) =
  backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)

function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
                                 workspace, reserve) where T
  @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
    (Ptr{Void}, Ptr{Void}, Cint,  # handle, rnnDesc, seqLength
     Ptr{Ptr{Void}}, Ptr{T},      # x
     Ptr{Void}, Ptr{T},           # hx
     Ptr{Ptr{Void}}, Ptr{T},      # y
     Ptr{Void}, Csize_t,          # ws
     Ptr{Void}, Ptr{T},           # dw
     Ptr{Void}, Csize_t),         # rs
    libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
    workspace, length(workspace), dwd, dw, reserve, length(reserve))
end

function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
  dw = zeros(rnn.params)
  cudnnRNNBackwardWeights(rnn, 1,
    xDesc(x), x, hDesc(h)..., xDesc(y), y,
    FilterDesc(T, (1, 1, length(dw))), dw,
    workspace[], reserve)
  return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end

# Interface

import ..Flux: Flux, relu
import ..Flux.Tracker: TrackedArray
using CUDAnative
using CuArrays: @cuindex, cudims

function copy_transpose!(dst::CuArray, src::CuArray)
  function kernel(dst, src)
    I = @cuindex dst
    dst[I...] = src[reverse(I)...]
    return
  end
  blk, thr = cudims(dst)
  @cuda (blk, thr) kernel(dst, src)
  return dst
end

CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}

function copyparams!(m::CuRNNs, d::RNNDesc)
  Wi, Wh = d.weights
  copy_transpose!(Wi, Flux.data(m.Wi))
  copy_transpose!(Wh, Flux.data(m.Wh))
  copy_transpose!(d.bias, Flux.data(m.b))
  return
end

function RNNDesc(m::CuRNNs{T}) where T
  h, i = length(m.h), size(m.Wi, 2)
  mode = m isa CuRNN ?
    (m.σ == tanh ? RNN_TANH : RNN_RELU) :
    m isa CuGRU ? GRU : LSTM
  r = RNNDesc{T}(mode, i, h)
  return r
end

const descs = WeakKeyDict()

function desc(rnn)
  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
  copyparams!(rnn, d)
  return d
end

import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast

mutable struct RNNCall{R}
  rnn::R
  reserve::CuVector{UInt8}
  RNNCall{R}(rnn::R) where R = new(rnn)
end

RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)

function (c::RNNCall)(args...)
  rs, result = forwardTrain(desc(c.rnn), args...)
  c.reserve = rs
  return result
end

istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))

function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
  result = istrain(m, h, x) ?
    track(RNNCall(m), x, h) :
    forward(desc(m), x, h)
  return result[2], result[1]
end

function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
  result = istrain(m, h, x) ?
    track(RNNCall(m), x, h) :
    forward(desc(m), x, h)
  return result[2], result[1]
end

function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
  result = istrain(m, h, x) ?
    track(RNNCall(m), x, h[1], h[2]) :
    forward(desc(m), x, h[1], h[2])
  return (result[2], result[3]), result[1]
end

function accum_transpose!(dst::CuArray, src::CuArray)
  function kernel(dst, src)
    I = @cuindex dst
    dst[I...] += src[reverse(I)...]
    return
  end
  blk, thr = cudims(dst)
  @cuda (blk, thr) kernel(dst, src)
  return dst
end

function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
  y, ho = y_
  dy, dho = Δ
  h_ = hBatch(x, data(h))
  dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
  @back(x, dx)
  @back(h, unbroadcast(h, dh))
  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
  # We don't have to make this assumption, it's just slightly more complex.
  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end

function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
  y, ho, co = y_
  dy, dho, dco = Δ
  h_ = hBatch(x, data(h))
  c_ = hBatch(x, data(c))
  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
  @back(x, dx)
  @back(h, unbroadcast(h, dh))
  @back(c, unbroadcast(h, dc))
  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end
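With the methods above, moving a cell's weights to the GPU is enough to route calls through cuDNN; a sketch mirroring the test added later in this commit (assumes a CUDA device with cuDNN available):

```julia
using Flux, CuArrays

rnn   = RNN(10, 5)
curnn = mapleaves(cu, rnn)      # leaves become CuArray-backed params
x     = cu(param(rand(10, 5)))
y     = curnn(x)                # dispatches to the CuRNN method above
```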
@@ -0,0 +1,8 @@
module JIT

include("shapes.jl")
include("inplace.jl")
include("trace.jl")
include("lib.jl")

end
@@ -0,0 +1,11 @@
mutable struct Cached{F,A}
  f::F
  buffer::A
end

function (c::Cached)(args...)
  sh = shape(c.f, shape(args)...)
  bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
  y = restructure(sh, c.buffer)
  inplace!(c.f, y, args...)
end
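A sketch of the intended use, relying on the `shape`/`inplace!` primitives defined in the next hunk (an illustration, not part of the diff):

```julia
c = Cached(broadcast, Vector{UInt8}(1))  # start with a 1-byte buffer
x, y = rand(2, 3), rand(2, 3)
c(+, x, y)  # grows the buffer to 48 bytes and writes x .+ y into it
c(+, x, y)  # the second call reuses the same buffer
```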
@@ -0,0 +1,9 @@
# Primitive definitions

inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
  A_mul_B!(C, A, B)

shape(::typeof(broadcast), f, xs...) =
  Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)

inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
@@ -0,0 +1,37 @@
struct Shape{T,N}
  dims::NTuple{N,Int}
end

VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}

Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)

Base.size(s::Shape) = s.dims
Base.size(s::Shape, n) = s.dims[n]
Base.length(s::Shape) = prod(s.dims)
Base.eltype(s::Shape{T}) where T = T

Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))

function Base.show(io::IO, s::Shape{T}) where T
  print(io, "Shape{$T}(")
  join(io, s.dims, ", ")
  print(io, ")")
end

shape(x) = typeof(x)
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)

bytes(s::Shape) = sizeof(s)
bytes(x::Tuple) = sum(bytes.(x))

# Recover structure from byte buffers
# Make sure to hold on to the parent buffer for the lifetime of the data.

function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
  buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
  reshape(reinterpret(T, buf), size(sh))
end
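A small sketch of `Shape` and `restructure` (illustrative only):

```julia
sh  = Shape{Float64}(2, 3)        # describes a 2×3 Float64 array
sizeof(sh)                        # 48 bytes
buf = Vector{UInt8}(sizeof(sh))   # raw backing store (Julia 0.6 constructor)
A   = restructure(sh, buf)        # 2×3 Array{Float64,2} viewing buf
```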
@@ -0,0 +1,25 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.

using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow
using DataFlow: inputnode, constant

vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)

graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
  vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)

function graph(x, inputs...; cache = ObjectIdDict())
  haskey(cache, x) && return cache[x]
  i = findfirst(inputs, x)
  cache[x] =
    i > 0 ? inputnode(i) :
    istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
    constant(x)
end

function trace(f, args...)
  inputs = param.(args)
  graph(f(inputs...), inputs...)
end
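A hypothetical usage sketch: `trace` runs the function on tracked inputs and replays the recorded calls as a DataFlow graph.

```julia
g = Flux.JIT.trace((a, b) -> a*b .+ b, rand(2, 2), rand(2, 2))
```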
@@ -113,15 +113,15 @@ function (BN::BatchNorm)(x)
  else
    T = eltype(x)

-    ϵ = T(BN.ϵ)
+    ϵ = data(convert(T, BN.ϵ))
    m = size(x, 2)  # batch size
    μ = mean(x, 2)
    σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)

    # update moving mean/std
-    mtm = T(BN.momentum)
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
+    mtm = data(convert(T, BN.momentum))
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ)
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
  end

  λ.(γ .* ((x .- μ) ./ σ) .+ β)
@@ -1,7 +1,6 @@
-# TODO: broadcasting cat
-combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
-combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
-combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
+gate(h, n) = (1:h) + h*(n-1)
+gate(x::AbstractVector, h, n) = x[gate(h,n)]
+gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]

 # Stateful recurrence
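The new `gate` helper replaces `combine`: it indexes the n-th size-`h` block of a stacked weight or bias. For example:

```julia
Flux.gate(5, 2)                 # 6:10, the second size-5 block
Flux.gate(collect(1:15), 5, 2)  # [6, 7, 8, 9, 10]
```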
@@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs)))

 # Vanilla RNN

-struct RNNCell{D,V}
-  d::D
+mutable struct RNNCell{F,A,V}
+  σ::F
+  Wi::A
+  Wh::A
+  b::V
   h::V
 end

-RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
-  RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
+RNNCell(in::Integer, out::Integer, σ = tanh;
+        init = glorot_uniform) =
+  RNNCell(σ, param(init(out, in)), param(init(out, out)),
+          param(zeros(out)), param(initn(out)))

 function (m::RNNCell)(h, x)
-  h = m.d(combine(x, h))
+  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
+  h = σ.(Wi*x .+ Wh*h .+ b)
   return h, h
 end
@@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h

 treelike(RNNCell)

-function Base.show(io::IO, m::RNNCell)
-  print(io, "RNNCell(", m.d, ")")
+function Base.show(io::IO, l::RNNCell)
+  print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
 end

 """
@@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

 # LSTM

-struct LSTMCell{D1,D2,V}
-  forget::D1
-  input::D1
-  output::D1
-  cell::D2
-  h::V; c::V
+mutable struct LSTMCell{A,V}
+  Wi::A
+  Wh::A
+  b::V
+  h::V
+  c::V
 end

-function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
-  cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
-                  Dense(in+out, out, tanh, initW = initW, initb = initb),
-                  param(initW(out)), param(initW(out)))
-  cell.forget.b.data .= 1
+function LSTMCell(in::Integer, out::Integer;
+                  init = glorot_uniform)
+  cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
+                  param(initn(out)), param(initn(out)))
+  cell.b.data[gate(out, 2)] = 1
   return cell
 end

 function (m::LSTMCell)(h_, x)
-  h, c = h_
-  x′ = combine(x, h)
-  forget, input, output, cell =
-    m.forget(x′), m.input(x′), m.output(x′), m.cell(x′)
+  h, c = h_ # TODO: nicer syntax on 0.7
+  b, o = m.b, size(h, 1)
+  g = m.Wi*x .+ m.Wh*h .+ b
+  input = σ.(gate(g, o, 1))
+  forget = σ.(gate(g, o, 2))
+  cell = tanh.(gate(g, o, 3))
+  output = σ.(gate(g, o, 4))
   c = forget .* c .+ input .* cell
-  h = output .* tanh.(c)
-  return (h, c), h
+  h′ = output .* tanh.(c)
+  return (h′, c), h′
 end

 hidden(m::LSTMCell) = (m.h, m.c)

 treelike(LSTMCell)

-Base.show(io::IO, m::LSTMCell) =
-  print(io, "LSTMCell(",
-        size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
-        size(m.forget.W, 1), ')')
+Base.show(io::IO, l::LSTMCell) =
+  print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")

 """
   LSTM(in::Integer, out::Integer, σ = tanh)
@@ -153,38 +161,33 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

 # GRU

-struct GRUCell{D1,D2,V}
-  update::D1
-  reset::D1
-  candidate::D2
+mutable struct GRUCell{A,V}
+  Wi::A
+  Wh::A
+  b::V
   h::V
 end

-function GRUCell(in, out)
-  cell = GRUCell(Dense(in+out, out, σ),
-                 Dense(in+out, out, σ),
-                 Dense(in+out, out, tanh),
-                 param(initn(out)))
-  return cell
-end
+GRUCell(in, out; init = glorot_uniform) =
+  GRUCell(param(init(out*3, in)), param(init(out*3, out)),
+          param(zeros(out*3)), param(initn(out)))

 function (m::GRUCell)(h, x)
-  x′ = combine(x, h)
-  z = m.update(x′)
-  r = m.reset(x′)
-  h̃ = m.candidate(combine(r.*h, x))
-  h = (1.-z).*h .+ z.*h̃
-  return h, h
+  b, o = m.b, size(h, 1)
+  gx, gh = m.Wi*x, m.Wh*h
+  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
+  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
+  h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
+  h′ = (1.-z).*h̃ .+ z.*h
+  return h′, h′
 end

 hidden(m::GRUCell) = m.h

 treelike(GRUCell)

-Base.show(io::IO, m::GRUCell) =
-  print(io, "GRUCell(",
-        size(m.update.W, 2) - size(m.update.W, 1), ", ",
-        size(m.update.W, 1), ')')
+Base.show(io::IO, l::GRUCell) =
+  print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")")

 """
   GRU(in::Integer, out::Integer, σ = tanh)
@@ -1,5 +1,5 @@
 using Juno
-using Flux.Tracker: back!, value
+using Flux.Tracker: back!

 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ())
   opt = runall(opt)
   @progress for d in data
     l = loss(d...)
-    isinf(value(l)) && error("Loss is Inf")
-    isnan(value(l)) && error("Loss is NaN")
+    isinf(l) && error("Loss is Inf")
+    isnan(l) && error("Loss is NaN")
     back!(l)
     opt()
     cb() == :stop && break
@@ -2,8 +2,12 @@ module Tracker

 export TrackedArray, TrackedVector, TrackedMatrix, param, back!

-data(x) = x
-istracked(x) = false
+tracker(x) = nothing
+
+istracked(x) = tracker(x) ≠ nothing
+isleaf(x) = !istracked(x) || isleaf(tracker(x))
+data(x) = istracked(x) ? data(tracker(x)) : x
+grad(x) = grad(tracker(x))

 struct Call{F,As<:Tuple}
   func::F
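Under the new scheme any value can be queried uniformly; a quick sketch:

```julia
using Flux.Tracker: param, istracked, isleaf, data

x = param([1.0, 2.0])    # a tracked leaf
istracked(x), isleaf(x)  # (true, true)
data(x)                  # the plain [1.0, 2.0]
istracked([1.0, 2.0])    # false — untracked values fall through
```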
@@ -14,109 +18,39 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)

 (c::Call)() = c.func(data.(c.args)...)

-mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
+mutable struct Tracked{T}
   ref::UInt32
   f::Call
-  data::A
-  grad::A
-  TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
-  TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
+  isleaf::Bool
+  data::T
+  grad::T
+  Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
+  Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
+  Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
 end

-TrackedScalar{T,A} = TrackedArray{T,0,A}
-TrackedVector{T,A} = TrackedArray{T,1,A}
-TrackedMatrix{T,A} = TrackedArray{T,2,A}
-TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
+Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
+Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)

-TrackedArray(c::Call, x::A) where A <: AbstractArray =
-  TrackedArray{eltype(A),ndims(A),A}(c, x)
+track(f::Call, x) = Tracked(f, x)
+track(f::Call) = track(f, f())
+track(f, xs...) = track(Call(f, xs...))

-TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
-  TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
-
-TrackedArray(c::Call) = TrackedArray(c, c())
-
-TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
-
-isleaf(x::TrackedArray) = x.f == Call(nothing)
-
-param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
-param(xs::Real) = param(fill(xs))
-
-istracked(x::TrackedArray) = true
-data(x::TrackedArray) = x.data
-grad(x::TrackedArray) = x.grad
-
-# Fallthrough methods
-
-for f in :[Base.size, Base.ndims].args
-  @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
-end
-
-Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
-  similar(data(x), dims...)
-
-Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
-
-# TODO decide if keeping both data and value. The problem is TrackedScalar
-value(x) = x
-value(x::TrackedArray) = data(x)
-value(x::TrackedScalar) = data(x)[]
-
-Base.:(==)(x::TrackedArray, y) = value(x) == y
-Base.:(==)(y, x::TrackedArray) = y == value(x)
-Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
-
-Base.isless(x::TrackedScalar, y) = isless(value(x), y)
-Base.isless(x, y::TrackedScalar) = isless(x, value(y))
-Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
-Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
-
-Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
-  print(io, "TrackedArray{…,$A}")
-
-function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
-  if repr
-    print(io, "param(")
-    Base.showarray(io, data(X), true)
-    print(io, ")")
-  else
-    header && print(io, "Tracked ")
-    Base.showarray(io, data(X), false, header = header)
-  end
-end
-
-Base.setindex!(xs::TrackedArray, v, i...) =
-  error("Can't differentiate `setindex!`")
+istracked(x::Tracked) = true
+isleaf(x::Tracked) = x.f == Call(nothing)
+data(x::Tracked) = x.data
+grad(x::Tracked) = x.grad

 include("back.jl")
-include("lib.jl")
+include("scalar.jl")
+include("array.jl")
 include("numeric.jl")

-using DataFlow
-using DataFlow: inputnode, constant
-
-vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
-vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
-
-function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
-  haskey(cache, x) && return cache[x]
-  i = findfirst(inputs, x)
-  cache[x] =
-    i > 0 ? inputnode(i) :
-    isleaf(x) ? constant(x) :
-    vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
-end
-
-_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
-
-function graph(f, args...)
-  inputs = param.(args)
-  _graph(f(inputs...), inputs...)
-end
+param(x::Number) = TrackedReal(float(x))
+param(xs::AbstractArray) = TrackedArray(float.(xs))

 import Adapt.adapt

-adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
+adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))

 end
@@ -1,45 +1,100 @@
-toarray(xs::AbstractArray, ys::AbstractArray) = ys
-toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
+struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
+  tracker::Tracked{A}
+  data::A
+  grad::A
+  TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
+  TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
+end

-unarray(xs) = xs
-unarray(xs::AbstractArray{T,0} where T) = xs[]
+tracker(x::TrackedArray) = x.tracker

-Base.getindex(xs::TrackedArray, i...) =
-  TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
+TrackedVector{T,A} = TrackedArray{T,1,A}
+TrackedMatrix{T,A} = TrackedArray{T,2,A}
+TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
+
+track(c::Call, x::AbstractArray) = TrackedArray(c, x)
+
+TrackedArray(c::Call, x::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
+
+TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
+
+TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
+
+Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
+
+Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
+  print(io, "TrackedArray{…,$A}")
+
+function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
+  if repr
+    print(io, "param(")
+    Base.showarray(io, data(X), true)
+    print(io, ")")
+  else
+    header && print(io, "Tracked ")
+    Base.showarray(io, data(X), false, header = header)
+  end
+end
+
+Base.setindex!(xs::TrackedArray, v, i...) =
+  error("Can't differentiate `setindex!`")
+
+back!(::TrackedArray) = error("Use back!(x, Δ)")
+
+# Fallthrough methods
+
+for f in :[Base.size, Base.ndims].args
+  @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
+end
+
+Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
+  similar(data(x), dims...)
+
+Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
+
+Base.:(==)(x::TrackedArray, y) = data(x) == y
+Base.:(==)(y, x::TrackedArray) = y == data(x)
+Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
+
+# Array Stdlib
+
+Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)

 function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
   Δ′ = zeros(xs.data)
-  Δ′[i...] = unarray(Δ)
+  Δ′[i...] = Δ
   @back(xs, Δ′)
 end

-Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
+Base.:-(xs::TrackedArray) = track(-, xs)

 back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)

-Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
-Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
+Base.transpose(xs::TrackedArray) = track(transpose, xs)
+Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)

 back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
 back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))

-Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
-Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
+Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
+Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)

-Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
-Base.vcat(a::TrackedVector, b::TrackedVector...) = TrackedArray(Call(vcat, a, b...))
-Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
-Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
+Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
+Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
+Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
+Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)

-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
-Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = TrackedArray(Call(vcat, a, b...))
-Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b))
-Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
+Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
+Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
+Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
+Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)

-Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
-Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = TrackedArray(Call(vcat, a, b...))
-Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
-Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
+Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
+Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
+Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
+Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)

 function back(::typeof(vcat), Δ, xs...)
   i = Base.tail(map(_ -> :, size(Δ)))
@@ -51,32 +106,32 @@ function back(::typeof(vcat), Δ, xs...)
 end

 Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
-  TrackedArray(Call(reshape, xs, dims...))
+  track(reshape, xs, dims...)

 back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
   back(xs, reshape(Δ, size(xs)))

 # Reductions

-Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
-Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
-Base.sum(xs::TrackedScalar, dim...) = xs
+Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
+Base.sum(xs::TrackedArray) = track(sum, xs)
 Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))

 back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)

 Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
 Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)

-Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
-Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
+Base.mean(xs::TrackedArray) = track(mean, xs)
+Base.mean(xs::TrackedArray, region) = track(mean, xs, region)

-LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
-LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
-LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
+LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
+LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)

 function back(::typeof(dot), Δ, xs, ys)
-  @back(xs, Δ.*ys)
-  @back(ys, Δ.*xs)
+  @back(xs, Δ.*data(ys))
+  @back(ys, Δ.*data(xs))
 end

 # Hacks to get std working
@@ -85,29 +140,34 @@ Base.std(x::TrackedArray; mean = Base.mean(x)) =
 Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
   sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))

+Base.norm(x::TrackedArray, p::Real = 2) =
+  p == 1 ? sum(abs.(x)) :
+  p == 2 ? sqrt(sum(abs2.(x))) :
+  error("$p-norm not supported")
+
 back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
 back(::typeof(mean), Δ, xs::TrackedArray, region) =
   back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))

 # BLAS

-Base.diagm(x::TrackedVector) = TrackedArray(Call(diagm, x))
+Base.diagm(x::TrackedVector) = track(diagm, x)
 back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))

 for f in :[*, Ac_mul_B, A_mul_Bc].args
   @eval begin
     import Base.$f
-    $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
-    $f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
-    $f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
+    $f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
+    $f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
+    $f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)

-    $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
-    $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
-    $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
+    $f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
+    $f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
+    $f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)

-    $f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
-    $f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
-    $f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
+    $f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
+    $f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
+    $f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
   end
 end
|
|||
using NNlib
|
||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, logσ, ∇logσ, conv2d, pool
|
||||
|
||||
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
||||
|
||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||
|
||||
logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
|
||||
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
||||
|
||||
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
|
||||
|
||||
|
@@ -157,11 +217,11 @@ back(::typeof(logσ), Δ, xs) = @back(xs, ∇logσ(Δ, data(xs)))
 _conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)

 conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
-  TrackedArray(Call(_conv2d, x, w, stride, padding))
+  track(_conv2d, x, w, stride, padding)
 conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
-  TrackedArray(Call(_conv2d, x, w, stride, padding))
+  track(_conv2d, x, w, stride, padding)
 conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
-  TrackedArray(Call(_conv2d, x, w, stride, padding))
+  track(_conv2d, x, w, stride, padding)

 function back(::typeof(_conv2d), Δ, x, w, stride, pad)
   @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
|
|||
_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
|
||||
|
||||
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
|
||||
TrackedArray(Call(_pool, x, window, padding, mode))
|
||||
track(_pool, x, window, padding, mode)
|
||||
|
||||
back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
|
||||
back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
|
||||
|
@ -189,23 +249,24 @@ end
|
|||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
||||
dualify(xs::TrackedReal, ps) = Dual(data(xs), ps)
|
||||
|
||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
# TrackedArray(Call(Broadcasted(f, broadcast(f, dargs...)), args...))
|
||||
# Works around a 0.6 type inference issue
|
||||
b = Broadcasted(f, out)
|
||||
TrackedArray(Call(b, args...), b())
|
||||
track(Call(b, args...), b())
|
||||
end
|
||||
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
|
||||
|
||||
unbroadcast(x, Δ) =
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
|
||||
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
|
||||
function getpartial(Δ, x, i)
|
||||
@inbounds p = getindex(partials(x), i)
|
||||
return Δ * p
|
|
@@ -1,25 +1,38 @@
-scan(x) = nothing
+init_grad(x) = zero(x)
+zero_grad!(x) = zero(x)
+zero_grad!(x::AbstractArray) = (x .= 0)

 scan(c::Call) = foreach(scan, c.args)

-function scan(x::TrackedArray)
+function scan(x::Tracked)
+  x.isleaf && return
   ref = x.ref += 1
   if ref == 1
     scan(x.f)
+    isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
   else
-    isdefined(x, :grad) || (x.grad = zeros(x.data))
+    isdefined(x, :grad) || (x.grad = init_grad(x.data))
   end
   return
 end

+function scan(x)
+  istracked(x) && scan(tracker(x))
+  return
+end
+
 back_(f, y, args...) = back(f, args...)
 back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
 back_(::Call{Void}, y, Δ) = nothing

-function back(x::TrackedArray, Δ)
+accum!(x, Δ) = x .+ Δ
+accum!(x::AbstractArray, Δ) = (x .+= Δ)
+
+function back(x::Tracked, Δ)
+  x.isleaf && (accum!(x.grad, Δ); return)
   ref = x.ref -= 1
   if isdefined(x, :grad)
-    x.grad .+= Δ
+    x.grad = accum!(x.grad, Δ)
     ref == 0 && back_(x.f, x.data, x.grad)
   else
     ref == 0 && back_(x.f, x.data, Δ)
@@ -27,6 +40,9 @@ function back(x::TrackedArray, Δ)
   return
 end

+back(x, Δ) = back(tracker(x), Δ)
+back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
+
 macro back(x, Δ)
   quote
     x = $(esc(x))

@@ -39,9 +55,9 @@
 # TODO: if an error occurs in `back` the refcounts will be broken
 # and `back` will silently fail to update.

-function back!(x::TrackedArray, Δ)
+function back!(x::Tracked, Δ)
   scan(x)
   back(x, Δ)
 end

-back!(x::TrackedScalar) = back!(x, 1)
+back!(x, Δ) = back!(tracker(x), Δ)
@@ -0,0 +1,86 @@
struct TrackedReal{T<:Real} <: Real
  tracker::Tracked{T}
end

TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x)))

tracker(x::TrackedReal) = x.tracker

track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))

back!(x::TrackedReal) = back!(x, 1)

function Base.show(io::IO, x::TrackedReal)
  show(io, data(x))
  print(io, " (tracked)")
end

Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x

Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
  TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))

Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))

Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)

for f in :[isinf, isnan, isfinite].args
  @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end

Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n)

Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
  TrackedReal{promote_type(S,T)}

using DiffRules, SpecialFunctions, NaNMath

for (M, f, arity) in DiffRules.diffrules()
  arity == 1 || continue
  @eval begin
    $M.$f(a::TrackedReal) = track($M.$f, a)
    back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
      back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
  end
end

for (M, f, arity) in DiffRules.diffrules()
  arity == 2 || continue
  da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
  @eval begin
    $M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b)
    $M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b)
    $M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b)
    function back(::typeof($M.$f), Δ::Real, a::Real, b::Real)
      @back(a, Δ * $da)
      @back(b, Δ * $db)
    end
  end
end

# Tuples

struct TrackedTuple{T<:Tuple}
  tracker::Tracked{T}
end

tracker(xs::TrackedTuple) = xs.tracker

accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)

track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))

function Base.show(io::IO, xs::TrackedTuple)
  show(io, data(xs))
  print(io, " (tracked)")
end

Base.length(x::TrackedTuple) = length(data(x))

Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)

back(::typeof(getindex), Δ, t, i) =
  back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
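A sketch of scalar tracking through the DiffRules-derived methods (the expected gradient comes from the product and chain rules, not from the diff):

```julia
using Flux.Tracker: param, back!, grad

x = param(2.0)                  # a TrackedReal
y = sin(x) * x
back!(y, 1)
grad(x) ≈ sin(2.0) + 2cos(2.0)  # d/dx (x·sin x) = sin x + x·cos x
```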
@@ -35,7 +35,10 @@ end

 function params(m)
   ps = []
-  prefor(p -> p isa TrackedArray && push!(ps, p), m)
+  prefor(p ->
+    Tracker.istracked(p) && Tracker.isleaf(p) &&
+    !(p in ps) && push!(ps, p),
+    m)
   return ps
 end
src/utils.jl
@@ -72,17 +72,6 @@ end

 # Other

-function accuracy(m, data)
-  n = 0
-  correct = 0
-  for (x, y) in data
-    x, y = tobatch.((x, y))
-    n += size(x, 1)
-    correct += sum(argmax(m(x)) .== argmax(y))
-  end
-  return correct/n
-end
-
 """
 Returns a function that when invoked, will only be triggered at most once
 during `timeout` seconds. Normally, the throttled function will run
@@ -21,3 +21,5 @@ cm = cu(m)
 @test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}

 end
+
+CuArrays.cudnn_available() && include("cudnn.jl")
@@ -0,0 +1,31 @@
using Flux, CuArrays, Base.Test

info("Testing Flux/CUDNN")

@testset "RNN" begin
  @testset for R in [RNN, GRU, LSTM]
    x = param(rand(10,5))
    cux = cu(x)
    rnn = R(10, 5)
    curnn = mapleaves(cu, rnn)
    y = (rnn(x); rnn(x))
    cuy = (curnn(cux); curnn(cux))

    @test y.data ≈ collect(cuy.data)
    @test haskey(Flux.CUDA.descs, curnn.cell)

    Δ = randn(size(y))

    Flux.back!(y, Δ)
    Flux.back!(cuy, cu(Δ))

    @test x.grad ≈ collect(cux.grad)
    @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
    @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
    @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
    @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
    if isdefined(rnn.cell, :c)
      @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
    end
  end
end
@@ -3,7 +3,7 @@ using Flux.Tracker

 @testset "Optimise" begin
   w = randn(10, 10)
-  for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Opt([w′])
@@ -1,5 +1,7 @@
 using Flux, Base.Test

+srand(0)
+
 @testset "Flux" begin

 include("utils.jl")
@@ -10,7 +12,7 @@ include("optimise.jl")
 include("data.jl")

 if Base.find_in_path("CuArrays") ≠ nothing
-  include("cuarrays.jl")
+  include("cuda/cuda.jl")
 end

 end
@@ -1,5 +1,5 @@
 using Flux.Tracker, Base.Test, NNlib
-using Flux.Tracker: gradcheck
+using Flux.Tracker: TrackedReal, gradcheck
+using NNlib

 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@@ -15,7 +15,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
 @test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))

-@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
+@test gradtest(x -> sum(x, (2, 3)), (3,4,5))

 @test gradtest(x -> softmax(x).*(1:3), 3)
 @test gradtest(x -> softmax(x).*(1:3), (3,5))
@@ -47,6 +47,7 @@ end
 @test gradtest(x -> std(x, 1), rand(5,5))

 @test gradtest((x, y) -> x .* y, rand(5), rand(5))
+@test gradtest(dot, rand(5), rand(5))

 @test gradtest(rand(5)) do x
   y = x.^2
@@ -59,4 +60,24 @@

 @test (param([1,2,3]) .< 2) == [true, false, false]

+@testset "Intermediates" begin
+  x = param([1])
+  l = sum((x .+ x).^2)
+  Flux.back!(l)
+  @test x.grad == [8]
+  x.grad .= 0
+  Flux.back!(l)
+  @test x.grad == [8]
+end
+
+@testset "Fallbacks" begin
+  xs = param([1 2; 3 4])
+  @test similar(xs) isa Matrix{Float64}
+  # Remove this test if we do LowerTriangular properly
+  L = LowerTriangular(xs)
+  @test L*L' isa Matrix{TrackedReal{Float64}}
+end
+
+@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
+
 end #testset
@@ -79,3 +79,10 @@ end
   @test std(v) < 1.1*sqrt(2/(n_in + n_out))
   end
 end
+
+@testset "Params" begin
+  m = Dense(10, 5)
+  @test size.(params(m)) == [(5, 10), (5,)]
+  m = RNN(10, 5)
+  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
+end