Merge branch 'master' into HEAD
This commit is contained in:
commit
e6be639436
|
@ -5,3 +5,4 @@ docs/build/
|
|||
docs/site/
|
||||
docs/flux.css
|
||||
deps
|
||||
Manifest.toml
|
||||
|
|
13
.travis.yml
13
.travis.yml
|
@ -4,11 +4,16 @@ os:
|
|||
- linux
|
||||
# - osx
|
||||
julia:
|
||||
- 0.6
|
||||
- 0.7
|
||||
- 1.0
|
||||
- nightly
|
||||
# uncomment the following lines to override the default test script
|
||||
script:
|
||||
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
||||
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
||||
# script:
|
||||
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
|
||||
# - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
|
||||
matrix:
|
||||
allow_failures:
|
||||
- julia: nightly
|
||||
after_success:
|
||||
- julia -e 'Pkg.add("Documenter")'
|
||||
- julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
|
||||
|
|
5
REQUIRE
5
REQUIRE
|
@ -1,14 +1,15 @@
|
|||
julia 0.6.0
|
||||
julia 0.7
|
||||
Juno
|
||||
MacroTools 0.3.3
|
||||
NNlib
|
||||
Requires
|
||||
Adapt
|
||||
GZip
|
||||
CodecZlib
|
||||
Colors
|
||||
ZipFile
|
||||
AbstractTrees
|
||||
Reexport
|
||||
StatsBase
|
||||
|
||||
# AD
|
||||
ForwardDiff 0.5.0
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.
|
||||
|
||||
```
|
||||
julia> using Flux: onehot
|
||||
julia> using Flux: onehot, onecold
|
||||
|
||||
julia> onehot(:b, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
|
@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
|
|||
true
|
||||
```
|
||||
|
||||
The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
|
||||
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
|
||||
|
||||
```julia
|
||||
julia> argmax(ans, [:a, :b, :c])
|
||||
julia> onecold(ans, [:a, :b, :c])
|
||||
:c
|
||||
|
||||
julia> argmax([true, false, false], [:a, :b, :c])
|
||||
julia> onecold([true, false, false], [:a, :b, :c])
|
||||
:a
|
||||
|
||||
julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
|
||||
:c
|
||||
```
|
||||
|
||||
## Batches
|
||||
|
||||
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
|
||||
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
|
||||
|
||||
```julia
|
||||
julia> using Flux: onehotbatch
|
||||
|
|
|
@ -134,7 +134,7 @@ All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around
|
|||
|
||||
```julia
|
||||
julia> x.tracker
|
||||
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
|
||||
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
|
||||
```
|
||||
|
||||
The `Tracker` stores the gradient of a given object, which we've seen before.
|
||||
|
|
|
@ -211,7 +211,7 @@ m(5) # => 26
|
|||
Flux provides a set of helpers for custom layers, which you can enable by calling
|
||||
|
||||
```julia
|
||||
Flux.treelike(Affine)
|
||||
Flux.@treelike Affine
|
||||
```
|
||||
|
||||
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
__precompile__()
|
||||
|
||||
module Flux
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
using Juno, Requires, Reexport
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
||||
|
@ -12,7 +10,6 @@ export Chain, Dense, RNN, LSTM, GRU, Conv,
|
|||
params, mapleaves, cpu, gpu
|
||||
|
||||
@reexport using NNlib
|
||||
using NNlib: @fix
|
||||
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
|
@ -37,6 +34,6 @@ include("layers/normalise.jl")
|
|||
|
||||
include("data/Data.jl")
|
||||
|
||||
@require CuArrays include("cuda/cuda.jl")
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")
|
||||
|
||||
end # module
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
module CUDA
|
||||
|
||||
using CuArrays
|
||||
using ..CuArrays
|
||||
|
||||
CuArrays.cudnn_available() && include("cudnn.jl")
|
||||
|
||||
|
|
|
@ -1,24 +1,27 @@
|
|||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
||||
cudnnDataType, TensorDesc, FilterDesc
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Void}
|
||||
ptr::Ptr{Nothing}
|
||||
states::CuVector{UInt8}
|
||||
end
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
|
||||
Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr
|
||||
|
||||
function DropoutDesc(ρ::Real; seed::Integer=0)
|
||||
d = [C_NULL]
|
||||
s = Csize_t[0]
|
||||
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
|
||||
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
|
||||
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
|
||||
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
|
||||
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
|
||||
desc = DropoutDesc(d[], states)
|
||||
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
|
||||
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
|
||||
desc,libcudnn_handle[],ρ,states,length(states),seed)
|
||||
finalizer(desc, x ->
|
||||
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
finalizer(desc) do x
|
||||
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
return desc
|
||||
end
|
||||
|
||||
|
@ -43,10 +46,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
|
|||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||
|
||||
function params(w::CuVector, input, hidden, n = 1)
|
||||
slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
|
||||
slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
|
||||
wx = slice(0, (input, hidden*n))
|
||||
wh = slice(length(wx), (hidden, hidden*n))
|
||||
bias = w[length(wx)+length(wh) + (1:hidden*n)]
|
||||
bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
|
||||
(wx, wh), bias
|
||||
end
|
||||
|
||||
|
@ -57,14 +60,14 @@ mutable struct RNNDesc{T}
|
|||
params::CuVector{T}
|
||||
weights::NTuple{2,CuMatrix{T}}
|
||||
bias::CuVector{T}
|
||||
ptr::Ptr{Void}
|
||||
ptr::Ptr{Nothing}
|
||||
end
|
||||
|
||||
Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
|
||||
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
|
||||
|
||||
function rnnParamSize(T, r, input)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
|
||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
|
||||
libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
|
||||
return Int(size[])÷sizeof(T)
|
||||
end
|
||||
|
@ -74,26 +77,27 @@ ngates(r::RNNDesc) = ngates(r.mode)
|
|||
|
||||
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
d = [C_NULL]
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
|
||||
|
||||
dropoutDesc = DropoutDesc(0)
|
||||
inputMode = LINEAR_INPUT
|
||||
direction = UNIDIRECTIONAL
|
||||
algo = RNN_ALGO_STANDARD
|
||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
|
||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
|
||||
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
finalizer(rd, x ->
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
finalizer(rd) do x
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
return rd
|
||||
end
|
||||
|
||||
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
|
||||
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
|
||||
libcudnn_handle[], r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
@ -110,7 +114,7 @@ getworkspace(r::RNNDesc, seqlen, xdesc) =
|
|||
|
||||
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
||||
size = Csize_t[0]
|
||||
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
|
||||
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
|
||||
libcudnn_handle[], r, seqlen, xdesc, size)
|
||||
return Int(size[])
|
||||
end
|
||||
|
@ -119,19 +123,19 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
|
|||
workspace, reserve=nothing) where T
|
||||
if reserve == nothing
|
||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t),
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace))
|
||||
else
|
||||
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen,
|
||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||
workspace, length(workspace), reserve, length(reserve))
|
||||
|
@ -140,7 +144,7 @@ end
|
|||
|
||||
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
||||
|
||||
hDesc(h::Void) = C_NULL, C_NULL
|
||||
hDesc(h::Nothing) = C_NULL, C_NULL
|
||||
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
||||
function hDesc(h::CuArray)
|
||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||
|
@ -187,18 +191,18 @@ forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T
|
|||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint,
|
||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
|
||||
Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
|
||||
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
||||
libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||
end
|
||||
|
||||
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||
# Same as above, any more efficient way?
|
||||
dy = dy_ isa Integer ? zeros(y) : dy_
|
||||
dy = dy_ isa Integer ? zero(y) : dy_
|
||||
yd = xDesc(y)
|
||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
|
@ -217,19 +221,19 @@ backwardData(rnn, y, dy, dho, hx, reserve) =
|
|||
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
||||
workspace, reserve) where T
|
||||
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
||||
(Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength
|
||||
Ptr{Ptr{Void}}, Ptr{T}, #x
|
||||
Ptr{Void}, Ptr{T}, #hx
|
||||
Ptr{Ptr{Void}}, Ptr{T}, #y
|
||||
Ptr{Void}, Csize_t, #ws
|
||||
Ptr{Void}, Ptr{T}, #dw
|
||||
Ptr{Void}, Csize_t), #rs
|
||||
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, #x
|
||||
Ptr{Nothing}, Ptr{T}, #hx
|
||||
Ptr{Ptr{Nothing}}, Ptr{T}, #y
|
||||
Ptr{Nothing}, Csize_t, #ws
|
||||
Ptr{Nothing}, Ptr{T}, #dw
|
||||
Ptr{Nothing}, Csize_t), #rs
|
||||
libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
|
||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||
end
|
||||
|
||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||
dw = zeros(rnn.params)
|
||||
dw = zero(rnn.params)
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
|
@ -241,17 +245,17 @@ end
|
|||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Tracker: TrackedArray
|
||||
using CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
using .CuArrays.CUDAnative
|
||||
using .CuArrays: @cuindex, cudims
|
||||
|
||||
function copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
|
@ -324,7 +328,7 @@ end
|
|||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
|
||||
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -339,6 +343,6 @@ end
|
|||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
|
||||
dWi.', dWh.', db))
|
||||
transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
|
|
@ -11,6 +11,8 @@ function __init__()
|
|||
end
|
||||
|
||||
include("mnist.jl")
|
||||
export MNIST
|
||||
|
||||
include("cmudict.jl")
|
||||
using .CMUDict
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ function load()
|
|||
return
|
||||
end
|
||||
end
|
||||
info("Downloading CMUDict dataset")
|
||||
@info "Downloading CMUDict dataset"
|
||||
mkpath(deps("cmudict"))
|
||||
for x in suffixes
|
||||
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
|
||||
|
@ -24,25 +24,25 @@ end
|
|||
|
||||
function phones()
|
||||
load()
|
||||
Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")),
|
||||
"\n", keep = false), "\t")))
|
||||
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
|
||||
"\n", keepempty = false), "\t")))
|
||||
end
|
||||
|
||||
function symbols()
|
||||
load()
|
||||
Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
|
||||
"\n", keep = false))
|
||||
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
|
||||
"\n", keepempty = false))
|
||||
end
|
||||
|
||||
function rawdict()
|
||||
load()
|
||||
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
|
||||
filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
|
||||
filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n"))))
|
||||
end
|
||||
|
||||
validword(s) = ismatch(r"^[\w\-\.]+$", s)
|
||||
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
|
||||
|
||||
cmudict() = filter((s, ps) -> validword(s), rawdict())
|
||||
cmudict() = filter(p -> validword(p.first), rawdict())
|
||||
|
||||
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
|
||||
|
||||
|
|
|
@ -1,11 +1,17 @@
|
|||
module MNIST
|
||||
|
||||
using GZip, Colors
|
||||
using CodecZlib, Colors
|
||||
|
||||
const Gray = Colors.Gray{Colors.N0f8}
|
||||
|
||||
const dir = joinpath(@__DIR__, "../../deps/mnist")
|
||||
|
||||
function gzopen(f, file)
|
||||
open(file) do io
|
||||
f(GzipDecompressorStream(io))
|
||||
end
|
||||
end
|
||||
|
||||
function load()
|
||||
mkpath(dir)
|
||||
cd(dir) do
|
||||
|
@ -14,10 +20,10 @@ function load()
|
|||
"t10k-images-idx3-ubyte",
|
||||
"t10k-labels-idx1-ubyte"]
|
||||
isfile(file) && continue
|
||||
info("Downloading MNIST dataset")
|
||||
@info "Downloading MNIST dataset"
|
||||
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
|
||||
open(file, "w") do io
|
||||
write(io, GZip.open(read, "$file.gz"))
|
||||
write(io, gzopen(read, "$file.gz"))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -49,7 +55,7 @@ function labelheader(io::IO)
|
|||
end
|
||||
|
||||
function rawimage(io::IO)
|
||||
img = Array{Gray}(NCOLS, NROWS)
|
||||
img = Array{Gray}(undef, NCOLS, NROWS)
|
||||
for i in 1:NCOLS, j in 1:NROWS
|
||||
img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
|
||||
end
|
||||
|
|
|
@ -5,7 +5,7 @@ using ..Data: deps
|
|||
|
||||
function load()
|
||||
isfile(deps("sentiment.zip")) || return
|
||||
info("Downloading sentiment treebank dataset")
|
||||
@info "Downloading sentiment treebank dataset"
|
||||
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
|
||||
deps("sentiment.zip"))
|
||||
end
|
||||
|
@ -14,7 +14,7 @@ getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
|
|||
|
||||
function getfile(name)
|
||||
r = ZipFile.Reader(deps("sentiment.zip"))
|
||||
text = readstring(getfile(r, "trees/$name"))
|
||||
text = read(getfile(r, "trees/$name"), String)
|
||||
close(r)
|
||||
return text
|
||||
end
|
||||
|
@ -29,12 +29,12 @@ function parsetree(s)
|
|||
s = replace(s, r"\$", s -> "\\\$")
|
||||
s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
|
||||
s = replace(s, " ", ", ")
|
||||
return totree(parse(s))
|
||||
return totree(Meta.parse(s))
|
||||
end
|
||||
|
||||
function gettrees(name)
|
||||
load()
|
||||
ss = split(getfile("$name.txt"), '\n', keep = false)
|
||||
ss = split(getfile("$name.txt"), '\n', keepempty = false)
|
||||
return parsetree.(ss)
|
||||
end
|
||||
|
||||
|
|
|
@ -16,19 +16,19 @@ m(x) == m[2](m[1](x))
|
|||
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
|
||||
`m[1:3](x)` will calculate the output of the first three layers.
|
||||
"""
|
||||
type Chain
|
||||
struct Chain
|
||||
layers::Vector{Any}
|
||||
Chain(xs...) = new([xs...])
|
||||
end
|
||||
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||
@forward Chain.layers Base.start, Base.next, Base.done
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
|
||||
@forward Chain.layers Base.iterate
|
||||
|
||||
children(c::Chain) = c.layers
|
||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
|
||||
|
||||
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
|
@ -38,10 +38,7 @@ function Base.show(io::IO, c::Chain)
|
|||
print(io, ")")
|
||||
end
|
||||
|
||||
# Seem to need this for `accumulate`; try removing on 0.7
|
||||
Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
|
||||
|
||||
activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
|
||||
activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)
|
||||
|
||||
"""
|
||||
Dense(in::Integer, out::Integer, σ = identity)
|
||||
|
@ -76,11 +73,11 @@ function Dense(in::Integer, out::Integer, σ = identity;
|
|||
return Dense(param(initW(out, in)), param(initb(out)), σ)
|
||||
end
|
||||
|
||||
treelike(Dense)
|
||||
@treelike Dense
|
||||
|
||||
function (a::Dense)(x)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
@fix σ.(W*x .+ b)
|
||||
σ.(W*x .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Dense)
|
||||
|
@ -107,7 +104,7 @@ end
|
|||
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
||||
Diagonal(param(initα(in)), param(initβ(in)))
|
||||
|
||||
treelike(Diagonal)
|
||||
@treelike Diagonal
|
||||
|
||||
function (a::Diagonal)(x)
|
||||
α, β = a.α, a.β
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
using NNlib: conv
|
||||
|
||||
@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
|
||||
@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2)))
|
||||
|
||||
expand(N, i::Tuple) = i
|
||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||
|
@ -35,7 +35,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
|
|||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
Flux.treelike(Conv)
|
||||
@treelike Conv
|
||||
|
||||
function (c::Conv)(x)
|
||||
# TODO: breaks gpu broadcast :(
|
||||
|
|
|
@ -58,7 +58,7 @@ end
|
|||
LayerNorm(h::Integer) =
|
||||
LayerNorm(Diagonal(h))
|
||||
|
||||
treelike(LayerNorm)
|
||||
@treelike LayerNorm
|
||||
|
||||
(a::LayerNorm)(x) = a.diag(normalise(x))
|
||||
|
||||
|
@ -108,7 +108,7 @@ mutable struct BatchNorm{F,V,W,N}
|
|||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
|
||||
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
|
@ -130,13 +130,13 @@ function (BN::BatchNorm)(x)
|
|||
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||
μ = mean(x, axes)
|
||||
σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
|
||||
μ = mean(x, dims = axes)
|
||||
σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)
|
||||
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, BN.momentum))
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
|
||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
|
||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
|
||||
end
|
||||
|
||||
let λ = BN.λ
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
gate(h, n) = (1:h) + h*(n-1)
|
||||
gate(h, n) = (1:h) .+ h*(n-1)
|
||||
gate(x::AbstractVector, h, n) = x[gate(h,n)]
|
||||
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
|
||||
|
||||
|
@ -38,7 +38,7 @@ function (m::Recur)(xs...)
|
|||
return y
|
||||
end
|
||||
|
||||
treelike(Recur, (:cell, :init))
|
||||
@treelike Recur cell, init
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
|
@ -84,7 +84,7 @@ end
|
|||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||
init = glorot_uniform) =
|
||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||
param(zeros(out)), param(initn(out)))
|
||||
param(zeros(out)), param(init(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
|
@ -94,7 +94,7 @@ end
|
|||
|
||||
hidden(m::RNNCell) = m.h
|
||||
|
||||
treelike(RNNCell)
|
||||
@treelike RNNCell
|
||||
|
||||
function Base.show(io::IO, l::RNNCell)
|
||||
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
|
||||
|
@ -123,8 +123,8 @@ end
|
|||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||
param(initn(out)), param(initn(out)))
|
||||
cell.b.data[gate(out, 2)] = 1
|
||||
param(init(out)), param(init(out)))
|
||||
cell.b.data[gate(out, 2)] .= 1
|
||||
return cell
|
||||
end
|
||||
|
||||
|
@ -143,7 +143,7 @@ end
|
|||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
treelike(LSTMCell)
|
||||
@treelike LSTMCell
|
||||
|
||||
Base.show(io::IO, l::LSTMCell) =
|
||||
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
|
||||
|
@ -170,7 +170,7 @@ end
|
|||
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(zeros(out*3)), param(initn(out)))
|
||||
param(zeros(out*3)), param(init(out)))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
|
@ -178,13 +178,13 @@ function (m::GRUCell)(h, x)
|
|||
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
|
||||
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
|
||||
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
|
||||
h′ = (1.-z).*h̃ .+ z.*h
|
||||
h′ = (1 .- z).*h̃ .+ z.*h
|
||||
return h′, h′
|
||||
end
|
||||
|
||||
hidden(m::GRUCell) = m.h
|
||||
|
||||
treelike(GRUCell)
|
||||
@treelike GRUCell
|
||||
|
||||
Base.show(io::IO, l::GRUCell) =
|
||||
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
|
||||
|
|
|
@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
|
|||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
end
|
||||
|
||||
@deprecate logloss(x, y) crossentropy(x, y)
|
||||
|
|
|
@ -32,20 +32,21 @@ import Adapt.adapt
|
|||
|
||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
||||
@require CuArrays begin
|
||||
import CuArrays: CuArray, cudaconvert
|
||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
import .CuArrays: CuArray, cudaconvert
|
||||
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
end
|
||||
|
||||
function onehot(l, labels)
|
||||
i = findfirst(labels, l)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || error("Value $l is not in labels")
|
||||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
||||
function onehot(l, labels, unk)
|
||||
i = findfirst(labels, l)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || return onehot(unk, labels)
|
||||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
@ -53,11 +54,15 @@ end
|
|||
onehotbatch(ls, labels, unk...) =
|
||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||
|
||||
argmax(y::AbstractVector, labels = 1:length(y)) =
|
||||
labels[findfirst(y, maximum(y))]
|
||||
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||
|
||||
argmax(y::AbstractMatrix, l...) =
|
||||
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
|
||||
onecold(y::AbstractMatrix, labels...) =
|
||||
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||
|
||||
function argmax(xs...)
|
||||
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
||||
return onecold(xs...)
|
||||
end
|
||||
|
||||
# Ambiguity hack
|
||||
|
||||
|
|
|
@ -2,14 +2,14 @@ module Optimise
|
|||
|
||||
export train!,
|
||||
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException
|
||||
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δ::T
|
||||
end
|
||||
|
||||
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
||||
Param(x::AbstractArray) = Param(x, zero(x))
|
||||
|
||||
include("optimisers.jl")
|
||||
include("interface.jl")
|
||||
|
@ -17,6 +17,7 @@ include("train.jl")
|
|||
|
||||
using Flux.Tracker: TrackedArray
|
||||
|
||||
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
||||
Param(x::TrackedArray) = Param(x.data, x.grad)
|
||||
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
||||
|
||||
end
|
||||
|
|
|
@ -14,7 +14,7 @@ function descentweightdecay(p::Param, η::Real, γ::Real)
|
|||
end
|
||||
|
||||
function momentum(p::Param, ρ, η)
|
||||
v = zeros(p.x)
|
||||
v = zero(p.x)
|
||||
function ()
|
||||
@. v = ρ * v - η * p.Δ
|
||||
@. p.Δ = -v
|
||||
|
@ -23,7 +23,7 @@ end
|
|||
|
||||
# Ref. https://arxiv.org/pdf/1212.0901.pdf
|
||||
function nesterov(p::Param, ρ, η)
|
||||
v = zeros(p.x)
|
||||
v = zero(p.x)
|
||||
function ()
|
||||
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
|
||||
@. v = ρ*v - η*p.Δ
|
||||
|
@ -32,7 +32,7 @@ function nesterov(p::Param, ρ, η)
|
|||
end
|
||||
|
||||
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||
acc = zeros(p.x)
|
||||
acc = zero(p.x)
|
||||
function ()
|
||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||
@. p.Δ *= η / √(acc + ϵ)
|
||||
|
@ -40,7 +40,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
|||
end
|
||||
|
||||
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
||||
acc = zeros(p.x) .+ ϵ
|
||||
acc = zero(p.x) .+ ϵ
|
||||
function ()
|
||||
@. acc += p.Δ^2
|
||||
@. p.Δ *= η / √(acc + ϵ)
|
||||
|
@ -48,8 +48,8 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
|||
end
|
||||
|
||||
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||
acc = zeros(p.x)
|
||||
Δacc = zeros(p.x)
|
||||
acc = zero(p.x)
|
||||
Δacc = zero(p.x)
|
||||
function ()
|
||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
||||
|
@ -58,8 +58,8 @@ function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
|||
end
|
||||
|
||||
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
vt = zeros(p.x)
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x)
|
||||
β1p, β2p = β1, β2
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
|
@ -71,8 +71,8 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
|
|||
end
|
||||
|
||||
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
ut = zeros(p.x)
|
||||
mt = zero(p.x)
|
||||
ut = zero(p.x)
|
||||
β1p = β1
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
|
@ -83,9 +83,9 @@ function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999,
|
|||
end
|
||||
|
||||
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
vt = zeros(p.x) .+ ϵ
|
||||
v̂t = zeros(p.x) .+ ϵ
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x) .+ ϵ
|
||||
v̂t = zero(p.x) .+ ϵ
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
||||
|
@ -95,8 +95,8 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
|
|||
end
|
||||
|
||||
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
vt = zeros(p.x)
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x)
|
||||
β1p, β2p = β1, β2
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using Juno
|
||||
using Flux.Tracker: back!
|
||||
import Base.depwarn
|
||||
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
@ -14,6 +15,25 @@ macro interrupts(ex)
|
|||
end)
|
||||
end
|
||||
|
||||
struct StopException <: Exception end
|
||||
"""
|
||||
stop()
|
||||
|
||||
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
|
||||
This would trigger the train loop to stop and exit.
|
||||
|
||||
```julia
|
||||
# Example callback:
|
||||
|
||||
cb = function ()
|
||||
accuracy() > 0.9 && Flux.stop()
|
||||
end
|
||||
```
|
||||
"""
|
||||
function stop()
|
||||
throw(StopException())
|
||||
end
|
||||
|
||||
"""
|
||||
train!(loss, data, opt)
|
||||
|
||||
|
@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
|
|||
cb = runall(cb)
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
cb() == :stop && break
|
||||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
end
|
||||
catch ex
|
||||
if ex isa StopException
|
||||
break
|
||||
else
|
||||
rethrow(ex)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -59,7 +90,7 @@ hello
|
|||
"""
|
||||
macro epochs(n, ex)
|
||||
:(@progress for i = 1:$(esc(n))
|
||||
info("Epoch $i")
|
||||
@info "Epoch $i"
|
||||
$(esc(ex))
|
||||
end)
|
||||
end
|
||||
|
|
|
@ -12,7 +12,7 @@ tracker(x) = nothing
|
|||
istracked(x) = tracker(x) ≠ nothing
|
||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||
grad(x) = grad(tracker(x))
|
||||
grad(::Void) = nothing
|
||||
grad(::Nothing) = nothing
|
||||
data(x) = x
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
|
@ -35,7 +35,7 @@ mutable struct Tracked{T}
|
|||
grad::T
|
||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||
Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
|
||||
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
||||
end
|
||||
|
||||
istracked(x::Tracked) = true
|
||||
|
@ -46,14 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
|
|||
|
||||
function _forward end
|
||||
|
||||
function track(f::F, xs...) where F
|
||||
y, back = _forward(f, xs...)
|
||||
ts = map(tracker, xs)
|
||||
c = Call(back, ts)
|
||||
track(c, y)
|
||||
end
|
||||
|
||||
function track_kw(f::F, xs...; kw...) where F
|
||||
function track(f::F, xs...; kw...) where F
|
||||
y, back = _forward(f, xs...; kw...)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
@ -84,10 +77,9 @@ include("numeric.jl")
|
|||
|
||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||
the sign of the gradient applied to `x`.
|
||||
"""
|
||||
the sign of the gradient applied to `x`."""
|
||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
|
||||
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
||||
|
||||
"""
|
||||
checkpoint(f, args...)
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
import Base: *, ==
|
||||
|
||||
import LinearAlgebra
|
||||
using Statistics
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||
|
||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
tracker::Tracked{A}
|
||||
data::A
|
||||
|
@ -21,24 +27,20 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
||||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||
if repr
|
||||
print(io, "param(")
|
||||
Base.showarray(io, data(X), true)
|
||||
print(io, ")")
|
||||
else
|
||||
header && print(io, "Tracked ")
|
||||
Base.showarray(io, data(X), false, header = header)
|
||||
end
|
||||
function Base.summary(io::IO, x::TrackedArray)
|
||||
print(io, "Tracked ")
|
||||
summary(io, data(x))
|
||||
end
|
||||
|
||||
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
|
||||
|
@ -46,7 +48,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back
|
|||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
|
@ -58,9 +60,9 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
Base.:(==)(x::TrackedArray, y) = data(x) == y
|
||||
Base.:(==)(y, x::TrackedArray) = y == data(x)
|
||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
|
||||
x::TrackedArray == y = data(x) == y
|
||||
y == x::TrackedArray = y == data(x)
|
||||
x::TrackedArray == y::TrackedArray = data(x) == data(y)
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
|
@ -79,37 +81,20 @@ Base.:-(xs::TrackedArray) = track(-, xs)
|
|||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
||||
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
||||
|
||||
@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),)
|
||||
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
|
||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||
|
||||
@grad function repmat(xs, m, n = 1)
|
||||
repmat(data(xs), m, n), function (Δ)
|
||||
Δ′ = similar(xs)
|
||||
S = size(xs)
|
||||
for (i,v) in enumerate(data(Δ))
|
||||
d1 = divrem(i-1, S[1]*m)
|
||||
x = d1[2] % S[1]+1
|
||||
y = d1[1] % S[2]+1
|
||||
Δ′[x, y] += v
|
||||
end
|
||||
return (nobacksies(:repmat, Δ′), nothing, nothing)
|
||||
end
|
||||
end
|
||||
|
||||
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
||||
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
S = size(xs)
|
||||
|
||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
|
||||
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||
# wrap around based on original size S.
|
||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||
|
@ -119,8 +104,8 @@ Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
|||
end
|
||||
end
|
||||
|
||||
|
||||
for f in [:vcat, :hcat]
|
||||
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
||||
@eval begin
|
||||
# This section is a bit of a hack since julia doesn't have a standardised
|
||||
# promotion mechanism for concatenation yet
|
||||
|
@ -129,18 +114,18 @@ for f in [:vcat, :hcat]
|
|||
# It should support tracked concatenation with rank ∈ (1,2) with a
|
||||
# TrackedArray anywhere among the arguments This works as long as base has
|
||||
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
||||
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
|
||||
Base.$f(a::$UArray...) = track($f, a...)
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# first
|
||||
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
||||
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# second
|
||||
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
||||
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
|
||||
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
|
||||
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
|
||||
c::$UArray...) =
|
||||
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
||||
end
|
||||
end
|
||||
|
@ -175,21 +160,23 @@ end
|
|||
end
|
||||
end
|
||||
|
||||
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
||||
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
||||
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
|
||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
|
||||
@grad function cat(dims, Xs...)
|
||||
cat(dims, data.(Xs)...), function (Δ)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
@grad function cat(Xs...; dims)
|
||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||
start = ntuple(i -> 0, Val(ndims(Δ)))
|
||||
Δs = [begin
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
|
||||
d = reshape(Δ[xs_in_Δ...],size(xs))
|
||||
start = start .+ till_xs
|
||||
d
|
||||
end for xs in Xs]
|
||||
return (nothing, Δs...,)
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -218,98 +205,86 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
@grad sum(xs, dim...) = sum(data(xs), dim...),
|
||||
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
|
||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, )
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||
@grad prod(xs, dim) = prod(data(xs), dim),
|
||||
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
||||
Δ -> (nobacksies(:sum,
|
||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||
nothing)
|
||||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
||||
|
||||
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
||||
Base.minimum(xs::TrackedArray) = track(minimum, xs)
|
||||
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
|
||||
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
||||
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
||||
|
||||
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
import LinearAlgebra: dot
|
||||
|
||||
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
# Hacks to get std working
|
||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
||||
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
||||
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
||||
|
||||
Base.vecnorm(x::TrackedArray, p::Real = 2) =
|
||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
|
||||
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
|
||||
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
||||
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
||||
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
||||
|
||||
@grad function maximum(xs, r...)
|
||||
maximum(data(xs), r...), function (Δ)
|
||||
@grad function maximum(xs; dims = dims)
|
||||
maximum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmax(data(xs), r...)
|
||||
_, i = findmax(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:maximum, Δ′),)
|
||||
end
|
||||
end
|
||||
@grad function minimum(xs, r...)
|
||||
minimum(data(xs), r...), function (Δ)
|
||||
|
||||
@grad function minimum(xs; dims = dims)
|
||||
minimum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmin(data(xs), r...)
|
||||
_, i = findmin(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:minimum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
||||
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
|
||||
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
|
||||
@eval begin
|
||||
import Base.$f
|
||||
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
|
||||
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
|
||||
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
|
||||
|
||||
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
|
||||
end
|
||||
end
|
||||
x::TrackedVector * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractVector * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedVector * y::TrackedVector = track(*, x, y)
|
||||
|
||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
|
||||
|
||||
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
|
||||
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
||||
|
||||
# NNlib
|
||||
|
||||
|
@ -324,9 +299,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
|
@ -334,14 +309,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
|||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
|
@ -352,45 +327,85 @@ end
|
|||
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||
|
||||
unbroadcast(x::Tuple, Δ) =
|
||||
x == size(Δ) ? Δ :
|
||||
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
length(x) == length(Δ) ? trim(x, Δ) :
|
||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||
|
||||
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
|
||||
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
|
||||
|
||||
function getpartial(Δ, x, i)
|
||||
@inbounds p = getindex(partials(x), i)
|
||||
return Δ * p
|
||||
dual(x, p) = x
|
||||
dual(x::Real, p) = Dual(x, p)
|
||||
|
||||
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
|
||||
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
|
||||
return Δ * f(dargs...).partials[1]
|
||||
end
|
||||
|
||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||
sizes = size.(args)
|
||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
y = value.(out)
|
||||
back = function (Δ_)
|
||||
Δ = data(Δ_)
|
||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
||||
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
||||
nobacksies(:broadcast, dxs)
|
||||
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
|
||||
y = broadcast(f, data.(args)...)
|
||||
eltype(y) <: Real || return y
|
||||
eltype(y) == Bool && return y
|
||||
function back(Δ)
|
||||
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
|
||||
dxs = unbroadcast.(args, Δargs)
|
||||
return nobacksies(:broadcast, dxs)
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, tracker.(args)), y)
|
||||
end
|
||||
|
||||
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
||||
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
|
||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
|
||||
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
||||
|
||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
|
||||
struct TrackedStyle <: BroadcastStyle end
|
||||
|
||||
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
||||
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
||||
|
||||
# We have to re-build the original broadcast struct to get the appropriate array
|
||||
# style. We need this primarily to support CuArrays' broadcasting fixes.
|
||||
broadcast_rebuild(xs) = data(xs)
|
||||
|
||||
broadcast_rebuild(bc::Broadcasted) =
|
||||
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
|
||||
|
||||
preprocess(x) = x
|
||||
|
||||
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
|
||||
bc1 = Broadcast.flatten(bc)
|
||||
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
|
||||
∇broadcast(bc2.f, bc1.args...)
|
||||
end
|
||||
|
||||
using Requires
|
||||
|
||||
# https://github.com/FluxML/Flux.jl/issues/353
|
||||
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||
isflat(bc) && return bc
|
||||
args = cat_nested(bc)
|
||||
let makeargs = make_makeargs(bc), f = bc.f
|
||||
newf = @inline function(args::Vararg{Any,N}) where N
|
||||
f(makeargs(args...)...)
|
||||
end
|
||||
return Broadcasted{Style}(newf, args, bc.axes)
|
||||
end
|
||||
end
|
||||
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
||||
bc = t[1]
|
||||
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
||||
let makeargs = make_makeargs(makeargs, bc.args)
|
||||
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
||||
return @inline function(args::Vararg{Any,N}) where N
|
||||
args1 = makeargs(args...)
|
||||
a, b = headargs(args1...), tailargs(args1...)
|
||||
(f(a...), b...)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -26,7 +26,7 @@ function back_(c::Call, Δ)
|
|||
foreach(back, c.args, data.(Δs))
|
||||
end
|
||||
|
||||
back_(::Call{Void}, Δ) = nothing
|
||||
back_(::Call{Nothing}, Δ) = nothing
|
||||
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
@ -47,7 +47,7 @@ function back(x::Tracked, Δ)
|
|||
return
|
||||
end
|
||||
|
||||
back(::Void, _) = return
|
||||
back(::Nothing, _) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
|
@ -70,7 +70,7 @@ struct Params
|
|||
Params(xs) = new(IdSet(xs))
|
||||
end
|
||||
|
||||
@forward Params.params Base.start, Base.next, Base.done
|
||||
@forward Params.params Base.iterate, Base.length
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
|
@ -79,14 +79,16 @@ function Base.show(io::IO, ps::Params)
|
|||
end
|
||||
|
||||
struct Grads
|
||||
grads::ObjectIdDict
|
||||
grads::IdDict{Any,Any}
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(ObjectIdDict())
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
function Base.getindex(g::Grads, x)
|
||||
|
@ -94,9 +96,8 @@ function Base.getindex(g::Grads, x)
|
|||
g[tracker(x)]
|
||||
end
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
|
@ -105,7 +106,7 @@ function back_(g::Grads, c::Call, Δ)
|
|||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||
end
|
||||
|
||||
back_(g::Grads, ::Call{Void}, Δ) = nothing
|
||||
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
|
||||
|
||||
function back(g::Grads, x::Tracked, Δ)
|
||||
x.isleaf && (accum!(g, x, Δ); return)
|
||||
|
@ -119,7 +120,7 @@ function back(g::Grads, x::Tracked, Δ)
|
|||
return
|
||||
end
|
||||
|
||||
back(::Grads, ::Void, _) = return
|
||||
back(::Grads, ::Nothing, _) = return
|
||||
|
||||
function forward(f, ps::Params)
|
||||
y = f()
|
||||
|
@ -136,7 +137,7 @@ end
|
|||
function forward(f, args...)
|
||||
args = param.(args)
|
||||
y, back = forward(() -> f(args...), Params(args))
|
||||
y, Δ -> getindex.(back(Δ), args)
|
||||
y, Δ -> getindex.(Ref(back(Δ)), args)
|
||||
end
|
||||
|
||||
function losscheck(x)
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
struct IdSet{T} <: AbstractSet{T}
|
||||
dict::ObjectIdDict
|
||||
IdSet{T}() where T = new(ObjectIdDict())
|
||||
dict::IdDict{T,Nothing}
|
||||
IdSet{T}() where T = new(IdDict{T,Nothing}())
|
||||
end
|
||||
|
||||
Base.eltype{T}(::IdSet{T}) = T
|
||||
Base.eltype(::IdSet{T}) where T = T
|
||||
|
||||
IdSet() = IdSet{Any}()
|
||||
|
||||
Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
|
||||
Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
|
||||
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||
|
||||
(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...)
|
||||
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
|
||||
|
||||
IdSet(xs) = IdSet{eltype(xs)}(xs)
|
||||
|
||||
|
@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
|||
|
||||
@forward IdSet.dict Base.length
|
||||
|
||||
Base.start(s::IdSet) = start(keys(s.dict))
|
||||
Base.next(s::IdSet, st) = next(keys(s.dict), st)
|
||||
Base.done(s::IdSet, st) = done(keys(s.dict), st)
|
||||
function Base.iterate(v::IdSet, state...)
|
||||
y = Base.iterate(keys(v.dict), state...)
|
||||
y === nothing && return nothing
|
||||
return (y[1], y[2])
|
||||
end
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
function ngradient(f, xs::AbstractArray...)
|
||||
grads = zeros.(xs)
|
||||
grads = zero.(xs)
|
||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||
δ = sqrt(eps())
|
||||
tmp = x[i]
|
||||
|
|
|
@ -115,3 +115,7 @@ end
|
|||
function back_(c::Call{typeof(collect)}, Δ)
|
||||
foreach(back, c.args[1], data(Δ))
|
||||
end
|
||||
|
||||
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
||||
end
|
||||
|
|
|
@ -7,16 +7,27 @@ mapchildren(f, x) = x
|
|||
children(x::Tuple) = x
|
||||
mapchildren(f, x::Tuple) = map(f, x)
|
||||
|
||||
function treelike(T, fs = fieldnames(T))
|
||||
@eval current_module() begin
|
||||
function treelike(m::Module, T, fs = fieldnames(T))
|
||||
@eval m begin
|
||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
||||
end
|
||||
end
|
||||
|
||||
function treelike(T, fs = fieldnames(T))
|
||||
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
|
||||
treelike(Base._current_module(), T, fs)
|
||||
end
|
||||
|
||||
macro treelike(T, fs = nothing)
|
||||
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
|
||||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
|
||||
end
|
||||
|
||||
isleaf(x) = isempty(children(x))
|
||||
|
||||
function mapleaves(f, x; cache = ObjectIdDict())
|
||||
function mapleaves(f, x; cache = IdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||
end
|
||||
|
@ -53,7 +64,7 @@ cpu(m) = mapleaves(x -> adapt(Array, x), m)
|
|||
|
||||
gpu_adaptor = identity
|
||||
|
||||
@require CuArrays begin
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
global gpu_adaptor = CuArrays.cu
|
||||
end
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Arrays
|
||||
|
||||
initn(dims...) = randn(dims...)/100
|
||||
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
||||
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
|
||||
|
||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||
|
||||
|
@ -119,7 +119,7 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||
end
|
||||
|
||||
cooldown = false
|
||||
@schedule try
|
||||
@async try
|
||||
while (sleep(timeout); later != nothing)
|
||||
later()
|
||||
later = nothing
|
||||
|
@ -145,7 +145,7 @@ function jacobian(m,x)
|
|||
y = m(xp)
|
||||
k = length(y)
|
||||
n = length(x)
|
||||
J = Matrix{eltype(x)}(n,k)
|
||||
J = Matrix{eltype(x)}(undef,n,k)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i]) # Populate gradient accumulator
|
||||
J[:,i] = xp.grad
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux: gpu
|
||||
|
||||
info("Testing Flux/GPU")
|
||||
@info "Testing Flux/GPU"
|
||||
|
||||
@testset "CuArrays" begin
|
||||
|
||||
|
@ -14,6 +14,7 @@ cx = gpu(x)
|
|||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = gpu(x)
|
||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||
@test (cx .+ 1) isa CuArray
|
||||
|
||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||
cm = gpu(m)
|
||||
|
@ -25,10 +26,13 @@ x = [1,2,3]
|
|||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
|
||||
# Fails in Pkg.test ffs
|
||||
# c = gpu(Conv((2,2),3=>4))
|
||||
# l = c(gpu(rand(10,10,3,2)))
|
||||
# Flux.back!(sum(l))
|
||||
xs = param(rand(5,5))
|
||||
ys = Flux.onehotbatch(1:5,1:5)
|
||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
|
||||
c = gpu(Conv((2,2),3=>4))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
Flux.back!(sum(l))
|
||||
|
||||
end
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
using Flux, CuArrays, Base.Test
|
||||
using Flux, CuArrays, Test
|
||||
|
||||
info("Testing Flux/CUDNN")
|
||||
@info "Testing Flux/CUDNN"
|
||||
|
||||
@testset "RNN" begin
|
||||
@testset for R in [RNN, GRU, LSTM]
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
using Flux.Data
|
||||
using Base.Test
|
||||
using Test
|
||||
|
||||
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
||||
|
||||
@test length(CMUDict.phones()) == 39
|
||||
|
||||
@test length(CMUDict.symbols()) == 84
|
||||
|
||||
@test MNIST.images()[1] isa Matrix
|
||||
@test MNIST.labels() isa Vector{Int64}
|
||||
|
|
|
@ -4,7 +4,7 @@ using Flux: testmode!
|
|||
x = [1.,2.,3.]
|
||||
@test x == testmode!(Dropout(0.1))(x)
|
||||
@test x == Dropout(0)(x)
|
||||
@test zeros(x) == Dropout(1)(x)
|
||||
@test zero(x) == Dropout(1)(x)
|
||||
|
||||
x = rand(100)
|
||||
m = Dropout(0.9)
|
||||
|
@ -53,17 +53,17 @@ end
|
|||
# .1 * 4 + 0 = .4
|
||||
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||
|
||||
# julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# 2×1 Array{Float64,2}:
|
||||
# 1.14495
|
||||
# 1.14495
|
||||
@test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
@test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
x′ = m(x).data
|
||||
@test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
|
||||
@test x′[1] ≈ (1 .- 0.3) / 1.1449489742783179
|
||||
end
|
||||
|
||||
# with activation function
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using Base.Test
|
||||
using Test
|
||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||
σ, binarycrossentropy, logitbinarycrossentropy
|
||||
|
||||
|
@ -42,8 +42,8 @@ const ϵ = 1e-7
|
|||
|
||||
logŷ, y = randn(3), rand(3)
|
||||
@testset "binarycrossentropy" begin
|
||||
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
|
||||
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
|
||||
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))
|
||||
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))
|
||||
end
|
||||
|
||||
@testset "logitbinarycrossentropy" begin
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
using Flux:onecold
|
||||
using Test
|
||||
|
||||
@testset "onecold" begin
|
||||
a = [1, 2, 5, 3.]
|
||||
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
|
||||
labels = ['A', 'B', 'C', 'D']
|
||||
|
||||
@test onecold(a) == 3
|
||||
@test onecold(A) == [3, 1, 4]
|
||||
@test onecold(a, labels) == 'C'
|
||||
@test onecold(A, labels) == ['C', 'A', 'D']
|
||||
end
|
|
@ -1,6 +1,6 @@
|
|||
using Flux.Optimise
|
||||
using Flux.Tracker
|
||||
|
||||
using Test
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
||||
|
@ -23,7 +23,7 @@ end
|
|||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||
Iterators.repeated((), 100),
|
||||
()->(),
|
||||
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
|
||||
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
|
||||
|
||||
@test 3 < i < 50
|
||||
end
|
||||
|
|
|
@ -1,17 +1,37 @@
|
|||
using Flux, Base.Test
|
||||
# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
|
||||
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
|
||||
if Base.JLOptions().check_bounds == 1
|
||||
file = @__FILE__
|
||||
run(```
|
||||
$(Base.julia_cmd())
|
||||
--color=$(Base.have_color ? "yes" : "no")
|
||||
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
|
||||
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
|
||||
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
|
||||
$(file)
|
||||
```)
|
||||
exit()
|
||||
end
|
||||
|
||||
srand(0)
|
||||
using Flux, Test, Random
|
||||
using Random
|
||||
|
||||
Random.seed!(0)
|
||||
|
||||
# So we can use the system CuArrays
|
||||
insert!(LOAD_PATH, 2, "@v#.#")
|
||||
|
||||
@testset "Flux" begin
|
||||
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
include("tracker.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("optimise.jl")
|
||||
include("data.jl")
|
||||
|
||||
if Base.find_in_path("CuArrays") ≠ nothing
|
||||
if Base.find_package("CuArrays") != nothing
|
||||
include("cuda/cuda.jl")
|
||||
end
|
||||
|
||||
|
|
|
@ -1,22 +1,27 @@
|
|||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux
|
||||
using Flux.Tracker, Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||
using NNlib: conv
|
||||
using Printf: @sprintf
|
||||
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
|
||||
using Statistics: mean, std
|
||||
using Random
|
||||
# using StatsBase
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
|
||||
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
||||
@testset "Tracker" begin
|
||||
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
|
||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||
|
||||
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
|
||||
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
||||
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
||||
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
||||
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x), (3,4,5))
|
||||
|
||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
|
@ -28,7 +33,6 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
||||
@test gradtest(x -> x', rand(5))
|
||||
|
||||
function promotiontest(f, A, B, C)
|
||||
r0 = f(A, B, C)
|
||||
r1 = f(param(A), B, C)
|
||||
|
@ -48,8 +52,8 @@ function promotiontest(f, A, B, C)
|
|||
end
|
||||
|
||||
@testset "concat" begin
|
||||
cat1(x...) = cat(1, x...)
|
||||
cat2(x...) = cat(2, x...)
|
||||
cat1(x...) = cat(x..., dims = 1)
|
||||
cat2(x...) = cat(x..., dims = 2)
|
||||
|
||||
@testset for vcatf in [vcat, cat1]
|
||||
@test gradtest(vcatf, rand(5), rand(3))
|
||||
|
@ -61,6 +65,7 @@ end
|
|||
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||
end
|
||||
|
||||
|
||||
@testset for hcatf in [hcat, cat2]
|
||||
@test gradtest(hcatf, rand(5), rand(5))
|
||||
@test gradtest(hcatf, rand(5)', rand(5)')
|
||||
|
@ -71,17 +76,17 @@ end
|
|||
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||
end
|
||||
|
||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||
@test gradtest(catf, rand(5))
|
||||
@test gradtest(catf, rand(5)')
|
||||
@test gradtest(catf, rand(2,5))
|
||||
@test gradtest(catf, rand(2,5,3))
|
||||
end
|
||||
|
||||
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||
|
||||
@testset "cat($dim, ...)" for dim in 3:5
|
||||
catdim = (x...) -> cat(dim, x...)
|
||||
catdim = (x...) -> cat(x..., dims = dim)
|
||||
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
||||
|
@ -89,12 +94,12 @@ end
|
|||
|
||||
@test !isa(vcat(rand(2)), TrackedArray)
|
||||
@test !isa(hcat(rand(2)), TrackedArray)
|
||||
@test !isa(cat(1,rand(2)), TrackedArray)
|
||||
@test !isa(cat(rand(2), dims=1), TrackedArray)
|
||||
|
||||
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
|
||||
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
|
||||
|
||||
@testset "promotiontest" begin
|
||||
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||
promotiontest(fcat, rand(2), rand(2), rand(2))
|
||||
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
|
||||
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
|
||||
|
@ -105,16 +110,14 @@ end
|
|||
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
|
||||
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
|
||||
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
|
||||
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
|
||||
# TODO unreliable
|
||||
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
|
||||
@test gradtest(x -> repmat(x, 5), rand(4,5))
|
||||
|
||||
@test gradtest(x -> repeat(x; inner=2), rand(5))
|
||||
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
|
||||
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
|
||||
|
||||
|
@ -124,54 +127,54 @@ end
|
|||
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
||||
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
||||
|
||||
@test gradtest(diagm, rand(3))
|
||||
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> mean(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
|
||||
@test gradtest(x -> mean(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
|
||||
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "maximum" begin
|
||||
@test gradtest(maximum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> maximum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
|
||||
@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
|
||||
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "minimum" begin
|
||||
@test gradtest(minimum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> minimum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
|
||||
@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
|
||||
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||
@test gradtest(x -> std(x, dims = 1), rand(5,5))
|
||||
|
||||
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
|
||||
@test gradtest(dot, rand(5), rand(5))
|
||||
|
||||
@test gradtest(vecnorm, rand(5))
|
||||
@test gradtest(norm, rand(5))
|
||||
|
||||
@test gradtest(rand(5)) do x
|
||||
y = x.^2
|
||||
2y + x
|
||||
end
|
||||
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
@ -211,14 +214,11 @@ end
|
|||
@testset "Fallbacks" begin
|
||||
xs = param([1 2; 3 4])
|
||||
@test similar(xs) isa Matrix{Float64}
|
||||
# Remove this test if we do LowerTriangular properly
|
||||
L = LowerTriangular(xs)
|
||||
@test L*L' isa Matrix{TrackedReal{Float64}}
|
||||
end
|
||||
|
||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
|
||||
|
||||
b = param(rand())
|
||||
Tracker.back!(b)
|
||||
|
@ -231,6 +231,11 @@ Tracker.back!(b)
|
|||
z = xy[1]*xy[2]
|
||||
back!(z)
|
||||
@test grad.((x,y)) == (3, 2)
|
||||
|
||||
@test Tracker.gradient(2, 3) do x, y
|
||||
xy = Tracker.collect([x, y])
|
||||
xy[1]*xy[2]
|
||||
end == (3, 2)
|
||||
end
|
||||
|
||||
# Gradient Hooks
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
|
||||
using Flux
|
||||
using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
|
||||
using StatsBase: std
|
||||
using Random
|
||||
using Test
|
||||
|
||||
@testset "Throttle" begin
|
||||
@testset "default behaviour" begin
|
||||
a = []
|
||||
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
|
||||
f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
|
||||
f()
|
||||
f()
|
||||
f()
|
||||
|
@ -13,7 +17,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
|
|||
|
||||
@testset "leading behaviour" begin
|
||||
a = []
|
||||
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
|
||||
f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
|
||||
f()
|
||||
@test length(a) == 1
|
||||
f()
|
||||
|
@ -25,7 +29,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
|
|||
|
||||
@testset "trailing behaviour" begin
|
||||
a = []
|
||||
f = throttle(()->push!(a, now()), 1, leading=false, trailing=true)
|
||||
f = throttle(()->push!(a, time()), 1, leading=false, trailing=true)
|
||||
f()
|
||||
@test length(a) == 0
|
||||
f()
|
||||
|
@ -59,7 +63,7 @@ end
|
|||
|
||||
@testset "Initialization" begin
|
||||
# Set random seed so that these tests don't fail randomly
|
||||
srand(0)
|
||||
Random.seed!(0)
|
||||
# initn() should yield a kernel with stddev ~= 1e-2
|
||||
v = initn(10, 10)
|
||||
@test std(v) > 0.9*1e-2
|
||||
|
|
Loading…
Reference in New Issue