Merge branch 'master' into HEAD
commit e6be639436
.gitignore (vendored) | 1
@@ -5,3 +5,4 @@ docs/build/
 docs/site/
 docs/flux.css
 deps
+Manifest.toml
.travis.yml | 13
@@ -4,11 +4,16 @@ os:
 - linux
 # - osx
 julia:
-- 0.6
+- 0.7
+- 1.0
+- nightly
 # uncomment the following lines to override the default test script
-script:
-- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
+# script:
+#  - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
+#  - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
+matrix:
+  allow_failures:
+    - julia: nightly
 after_success:
 - julia -e 'Pkg.add("Documenter")'
 - julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
REQUIRE | 5
@@ -1,14 +1,15 @@
-julia 0.6.0
+julia 0.7
 Juno
 MacroTools 0.3.3
 NNlib
 Requires
 Adapt
-GZip
+CodecZlib
 Colors
 ZipFile
 AbstractTrees
 Reexport
+StatsBase

 # AD
 ForwardDiff 0.5.0
@@ -3,7 +3,7 @@
 It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.

 ```
-julia> using Flux: onehot
+julia> using Flux: onehot, onecold

 julia> onehot(:b, [:a, :b, :c])
 3-element Flux.OneHotVector:
@@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
 true
 ```

-The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
+The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).

 ```julia
-julia> argmax(ans, [:a, :b, :c])
+julia> onecold(ans, [:a, :b, :c])
 :c

-julia> argmax([true, false, false], [:a, :b, :c])
+julia> onecold([true, false, false], [:a, :b, :c])
 :a

-julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
+julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
 :c
 ```

 ## Batches

-`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
+`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.

 ```julia
 julia> using Flux: onehotbatch
@@ -134,7 +134,7 @@ All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around

 ```julia
 julia> x.tracker
-Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
+Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
 ```

 The `Tracker` stores the gradient of a given object, which we've seen before.
@@ -211,7 +211,7 @@ m(5) # => 26
 Flux provides a set of helpers for custom layers, which you can enable by calling

 ```julia
-Flux.treelike(Affine)
+Flux.@treelike Affine
 ```

 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
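For context, a minimal sketch of the pattern this docs hunk now describes. The `Affine` definition below is an assumption, reconstructed to match the layer the same docs page builds earlier; the point is that the helper is now a macro:

```julia
using Flux

struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(param(randn(out, in)), param(randn(out)))

(m::Affine)(x) = m.W * x .+ m.b

Flux.@treelike Affine  # registers W and b so params(m) and gpu(m) can see them
```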
@@ -1,10 +1,8 @@
-__precompile__()
-
 module Flux

 # Zero Flux Given

-using Juno, Requires, Reexport
+using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

 export Chain, Dense, RNN, LSTM, GRU, Conv,
@@ -12,7 +10,6 @@ export Chain, Dense, RNN, LSTM, GRU, Conv,
   params, mapleaves, cpu, gpu

 @reexport using NNlib
-using NNlib: @fix

 include("tracker/Tracker.jl")
 using .Tracker
@@ -37,6 +34,6 @@ include("layers/normalise.jl")

 include("data/Data.jl")

-@require CuArrays include("cuda/cuda.jl")
+@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")

 end # module
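On Julia 0.7+, Requires.jl identifies optional dependencies by UUID and must register the hook at module load time, which `@init` arranges by generating an `__init__` method. A minimal sketch of the pattern with a hypothetical package `MyPkg` (the UUID shown is CuArrays', taken from the hunk above):

```julia
module MyPkg

using Requires

@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
  # runs only once CuArrays is loaded in the same session
  @info "CuArrays detected; enabling GPU support"
end

end
```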
@@ -1,6 +1,6 @@
 module CUDA

-using CuArrays
+using ..CuArrays

 CuArrays.cudnn_available() && include("cudnn.jl")

@@ -1,24 +1,27 @@
-using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
+using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
   cudnnDataType, TensorDesc, FilterDesc

+using LinearAlgebra
+
 mutable struct DropoutDesc
-  ptr::Ptr{Void}
+  ptr::Ptr{Nothing}
   states::CuVector{UInt8}
 end

-Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
+Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr

 function DropoutDesc(ρ::Real; seed::Integer=0)
   d = [C_NULL]
   s = Csize_t[0]
-  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
+  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
-  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
+  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
   states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
-  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
+  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
    desc,libcudnn_handle[],ρ,states,length(states),seed)
-  finalizer(desc, x ->
-    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
+  finalizer(desc) do x
+    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return desc
 end

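Julia 0.7 reversed `finalizer`'s argument order so the function comes first, which is why `finalizer(desc, x -> ...)` becomes the `do`-block form above (`finalizer(obj) do x ... end` is sugar for `finalizer(f, obj)`). A minimal sketch with a hypothetical mutable handle type (finalizers require mutable values):

```julia
mutable struct Handle
  ptr::Ptr{Nothing}
end

h = Handle(C_NULL)
finalizer(h) do x
  # runs when x is garbage-collected; release the underlying resource here
  x.ptr = C_NULL
end
```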
@@ -43,10 +46,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
 # LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

 function params(w::CuVector, input, hidden, n = 1)
-  slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
+  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
   wx = slice(0, (input, hidden*n))
   wh = slice(length(wx), (hidden, hidden*n))
-  bias = w[length(wx)+length(wh) + (1:hidden*n)]
+  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
   (wx, wh), bias
 end

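The dotted `.+` here is required because Julia 0.7 removed `+` between a number and a range or array in favour of explicit broadcasting:

```julia
offset = 4
idx = offset .+ (1:3)  # 5:7; plain `offset + (1:3)` was removed in 0.7
```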
@@ -57,14 +60,14 @@ mutable struct RNNDesc{T}
   params::CuVector{T}
   weights::NTuple{2,CuMatrix{T}}
   bias::CuVector{T}
-  ptr::Ptr{Void}
+  ptr::Ptr{Nothing}
 end

-Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
+Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr

 function rnnParamSize(T, r, input)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
+  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
     libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
   return Int(size[])÷sizeof(T)
 end
@@ -74,26 +77,27 @@ ngates(r::RNNDesc) = ngates(r.mode)

 function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   d = [C_NULL]
-  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
+  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)

   dropoutDesc = DropoutDesc(0)
   inputMode = LINEAR_INPUT
   direction = UNIDIRECTIONAL
   algo = RNN_ALGO_STANDARD
-  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
+  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
     libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))

   w = cuzeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
   rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
-  finalizer(rd, x ->
-    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
+  finalizer(rd) do x
+    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return rd
 end

 function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
+  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
     libcudnn_handle[], r, seqlen, xdesc, size)
   return Int(size[])
 end
@@ -110,7 +114,7 @@ getworkspace(r::RNNDesc, seqlen, xdesc) =

 function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
+  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
     libcudnn_handle[], r, seqlen, xdesc, size)
   return Int(size[])
 end
@@ -119,19 +123,19 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
                          workspace, reserve=nothing) where T
   if reserve == nothing
     @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
-                 (Ptr{Void}, Ptr{Void}, Cint,
+                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
+                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Void}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
-                  Ptr{Void}, Csize_t),
+                  Ptr{Nothing}, Csize_t),
                  libcudnn_handle[], rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace))
   else
     @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
-                 (Ptr{Void}, Ptr{Void}, Cint,
+                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
+                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
+                  Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
                  libcudnn_handle[], rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace), reserve, length(reserve))
@@ -140,7 +144,7 @@ end

 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

-hDesc(h::Void) = C_NULL, C_NULL
+hDesc(h::Nothing) = C_NULL, C_NULL
 hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -187,18 +191,18 @@ forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T
 function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
-               (Ptr{Void}, Ptr{Void}, Cint,
+               (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
+                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
+                Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
-                Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
+                Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
+                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
                libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end

 function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
   # Same as above, any more efficient way?
-  dy = dy_ isa Integer ? zeros(y) : dy_
+  dy = dy_ isa Integer ? zero(y) : dy_
   yd = xDesc(y)
   dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
   dh = similar(h)
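`zeros(y)` for an existing array `y` was deprecated in Julia 0.7; `zero(y)` is the replacement used throughout this commit, returning a zero-filled array with the same shape and element type:

```julia
y = [1.0 2.0; 3.0 4.0]
zero(y)         # 2×2 zero matrix matching y's shape and eltype
zeros(size(y))  # equivalent alternative, built from the dims instead
```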
@@ -217,19 +221,19 @@ backwardData(rnn, y, dy, dho, hx, reserve) =
 function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
                                  workspace, reserve) where T
   @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
-               (Ptr{Void}, Ptr{Void}, Cint,  # handle, rnnDesc, seqLength
+               (Ptr{Nothing}, Ptr{Nothing}, Cint,  # handle, rnnDesc, seqLength
-                Ptr{Ptr{Void}}, Ptr{T}, #x
+                Ptr{Ptr{Nothing}}, Ptr{T}, #x
-                Ptr{Void}, Ptr{T}, #hx
+                Ptr{Nothing}, Ptr{T}, #hx
-                Ptr{Ptr{Void}}, Ptr{T}, #y
+                Ptr{Ptr{Nothing}}, Ptr{T}, #y
-                Ptr{Void}, Csize_t, #ws
+                Ptr{Nothing}, Csize_t, #ws
-                Ptr{Void}, Ptr{T}, #dw
+                Ptr{Nothing}, Ptr{T}, #dw
-                Ptr{Void}, Csize_t), #rs
+                Ptr{Nothing}, Csize_t), #rs
                libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
                workspace, length(workspace), dwd, dw, reserve, length(reserve))
 end

 function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
-  dw = zeros(rnn.params)
+  dw = zero(rnn.params)
   cudnnRNNBackwardWeights(rnn, 1,
     xDesc(x), x, hDesc(h)..., xDesc(y), y,
     FilterDesc(T, (1, 1, length(dw))), dw,
@@ -241,17 +245,17 @@ end

 import ..Flux: Flux, relu
 import ..Tracker: TrackedArray
-using CUDAnative
-using CuArrays: @cuindex, cudims
+using .CuArrays.CUDAnative
+using .CuArrays: @cuindex, cudims

-function copy_transpose!(dst::CuArray, src::CuArray)
+function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   function kernel(dst, src)
     I = @cuindex dst
     dst[I...] = src[reverse(I)...]
     return
   end
   blk, thr = cudims(dst)
-  @cuda (blk, thr) kernel(dst, src)
+  @cuda blocks=blk threads=thr kernel(dst, src)
   return dst
 end

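CUDAnative's launch macro moved from a tuple of launch dimensions to keyword arguments, which is the only behavioural change in the kernel above. A sketch of the syntax under the assumption of a working CUDA setup (not runnable without a GPU; the kernel name is made up for illustration):

```julia
using CUDAnative, CuArrays  # assumption: CUDA-capable environment

function add_one!(xs)
  i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
  i <= length(xs) && (xs[i] += 1)
  return
end

xs = CuArray(zeros(Float32, 16))
# old form: @cuda (1, 16) add_one!(xs)
@cuda blocks=1 threads=16 add_one!(xs)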
@@ -324,7 +328,7 @@ end
     h_ = hBatch(x, data(h))
     dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
+    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
   end
 end

@@ -339,6 +343,6 @@ end
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
     nobacksies(:RNN,
       (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-       dWi.', dWh.', db))
+       transpose(dWi), transpose(dWh), db))
   end
 end
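The postfix `.'` transpose syntax was removed in Julia 1.0, which is why `dWi.'` becomes `transpose(dWi)` in both hunks; the `'` adjoint operator survives and agrees with `transpose` for real arrays:

```julia
A = [1 2; 3 4]
transpose(A)   # lazy Transpose wrapper; replaces A.'
A'             # adjoint; same values as transpose for real-valued A
permutedims(A) # eager, materialized copy
```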
@@ -11,6 +11,8 @@ function __init__()
 end

 include("mnist.jl")
+export MNIST

 include("cmudict.jl")
 using .CMUDict

@@ -14,7 +14,7 @@ function load()
       return
     end
   end
-  info("Downloading CMUDict dataset")
+  @info "Downloading CMUDict dataset"
   mkpath(deps("cmudict"))
   for x in suffixes
     download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
@@ -24,25 +24,25 @@ end

 function phones()
   load()
-  Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")),
-                              "\n", keep = false), "\t")))
+  Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
+                              "\n", keepempty = false), "\t")))
 end

 function symbols()
   load()
-  Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
-                "\n", keep = false))
+  Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
+                "\n", keepempty = false))
 end

 function rawdict()
   load()
   Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
-       filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
+       filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n"))))
 end

-validword(s) = ismatch(r"^[\w\-\.]+$", s)
+validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)

-cmudict() = filter((s, ps) -> validword(s), rawdict())
+cmudict() = filter(p -> validword(p.first), rawdict())

 alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']

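Most of that hunk is mechanical renaming of string and IO helpers in Julia 0.7 (the file name below is hypothetical):

```julia
text  = read("words.txt", String)             # was readstring("words.txt")
lines = split(text, "\n", keepempty = false)  # keyword was `keep`
ok    = occursin(r"^[\w\-\.]+$", "hello")     # was ismatch(r"...", s)
```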
@@ -1,11 +1,17 @@
 module MNIST

-using GZip, Colors
+using CodecZlib, Colors

 const Gray = Colors.Gray{Colors.N0f8}

 const dir = joinpath(@__DIR__, "../../deps/mnist")

+function gzopen(f, file)
+  open(file) do io
+    f(GzipDecompressorStream(io))
+  end
+end
+
 function load()
   mkpath(dir)
   cd(dir) do
@@ -14,10 +20,10 @@ function load()
                  "t10k-images-idx3-ubyte",
                  "t10k-labels-idx1-ubyte"]
       isfile(file) && continue
-      info("Downloading MNIST dataset")
+      @info "Downloading MNIST dataset"
       download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
       open(file, "w") do io
-        write(io, GZip.open(read, "$file.gz"))
+        write(io, gzopen(read, "$file.gz"))
       end
     end
   end
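The `gzopen` helper introduced in the previous hunk wraps CodecZlib's `GzipDecompressorStream`, replacing the GZip.jl dependency. A standalone sketch with a hypothetical file name:

```julia
using CodecZlib

# Read a gzip-compressed file into a byte vector:
data = open("file.gz") do io
  read(GzipDecompressorStream(io))
end
```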
@@ -49,7 +55,7 @@ function labelheader(io::IO)
 end

 function rawimage(io::IO)
-  img = Array{Gray}(NCOLS, NROWS)
+  img = Array{Gray}(undef, NCOLS, NROWS)
   for i in 1:NCOLS, j in 1:NROWS
     img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
   end
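Julia 0.7 made uninitialized allocation explicit, hence the `undef` argument above:

```julia
img = Array{Float64}(undef, 28, 28)  # was Array{Float64}(28, 28) on 0.6
```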
@@ -5,7 +5,7 @@ using ..Data: deps

 function load()
   isfile(deps("sentiment.zip")) || return
-  info("Downloading sentiment treebank dataset")
+  @info "Downloading sentiment treebank dataset"
   download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
            deps("sentiment.zip"))
 end
@@ -14,7 +14,7 @@ getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]

 function getfile(name)
   r = ZipFile.Reader(deps("sentiment.zip"))
-  text = readstring(getfile(r, "trees/$name"))
+  text = read(getfile(r, "trees/$name"), String)
   close(r)
   return text
 end
@@ -29,12 +29,12 @@ function parsetree(s)
   s = replace(s, r"\$", s -> "\\\$")
   s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
   s = replace(s, " ", ", ")
-  return totree(parse(s))
+  return totree(Meta.parse(s))
 end

 function gettrees(name)
   load()
-  ss = split(getfile("$name.txt"), '\n', keep = false)
+  ss = split(getfile("$name.txt"), '\n', keepempty = false)
   return parsetree.(ss)
 end

@@ -16,19 +16,19 @@ m(x) == m[2](m[1](x))
 `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
 `m[1:3](x)` will calculate the output of the first three layers.
 """
-type Chain
+struct Chain
   layers::Vector{Any}
   Chain(xs...) = new([xs...])
 end

-@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
-@forward Chain.layers Base.start, Base.next, Base.done
+@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
+@forward Chain.layers Base.iterate

 children(c::Chain) = c.layers
 mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
 adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)

-(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
+(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)

 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)

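The `Chain` call above relies on 0.7's `foldl`, which takes the starting value as an `init` keyword and fixes the argument order as (accumulator, element):

```julia
fs = (x -> x .+ 1, x -> 2 .* x)
foldl((x, f) -> f(x), fs; init = [1.0, 2.0])  # == [4.0, 6.0]
```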
@@ -38,10 +38,7 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end

-# Seem to need this for `accumulate`; try removing on 0.7
-Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
-
-activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
+activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)

 """
     Dense(in::Integer, out::Integer, σ = identity)
@@ -76,11 +73,11 @@ function Dense(in::Integer, out::Integer, σ = identity;
   return Dense(param(initW(out, in)), param(initb(out)), σ)
 end

-treelike(Dense)
+@treelike Dense

 function (a::Dense)(x)
   W, b, σ = a.W, a.b, a.σ
-  @fix σ.(W*x .+ b)
+  σ.(W*x .+ b)
 end

 function Base.show(io::IO, l::Dense)
@@ -107,7 +104,7 @@ end
 Diagonal(in::Integer; initα = ones, initβ = zeros) =
   Diagonal(param(initα(in)), param(initβ(in)))

-treelike(Diagonal)
+@treelike Diagonal

 function (a::Diagonal)(x)
   α, β = a.α, a.β
@@ -1,6 +1,6 @@
 using NNlib: conv

-@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
+@generated sub2(::Type{Val{N}}) where N = :(Val($(N-2)))

 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
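The `sub2` change reflects 0.7's preference for `Val` instances over `Val` types in value-dispatch positions (the same shift appears in the `ntuple` calls later in this diff):

```julia
ntuple(i -> 0, Val(3))  # (0, 0, 0); the Val{3} type form is deprecated here
```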
@@ -35,7 +35,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
        stride = stride, pad = pad, dilation = dilation)

-Flux.treelike(Conv)
+@treelike Conv

 function (c::Conv)(x)
   # TODO: breaks gpu broadcast :(
@@ -58,7 +58,7 @@ end
 LayerNorm(h::Integer) =
   LayerNorm(Diagonal(h))

-treelike(LayerNorm)
+@treelike LayerNorm

 (a::LayerNorm)(x) = a.diag(normalise(x))

@@ -108,7 +108,7 @@ mutable struct BatchNorm{F,V,W,N}
 end

 BatchNorm(chs::Integer, λ = identity;
-          initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
+          initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
             zeros(chs), ones(chs), ϵ, momentum, true)

@@ -130,13 +130,13 @@ function (BN::BatchNorm)(x)

     ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
-    μ = mean(x, axes)
-    σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
+    μ = mean(x, dims = axes)
+    σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)

     # update moving mean/std
     mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
   end

   let λ = BN.λ
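That hunk is entirely the 0.7 reduction API: reductions take a `dims` keyword, and `squeeze` was renamed to `dropdims`:

```julia
using Statistics

x = rand(4, 3)
μ = mean(x, dims = 1)      # 1×3 matrix; was mean(x, 1)
v = dropdims(μ, dims = 1)  # length-3 vector; was squeeze(μ, 1)
```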
@@ -1,4 +1,4 @@
-gate(h, n) = (1:h) + h*(n-1)
+gate(h, n) = (1:h) .+ h*(n-1)
 gate(x::AbstractVector, h, n) = x[gate(h,n)]
 gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]

@@ -38,7 +38,7 @@ function (m::Recur)(xs...)
   return y
 end

-treelike(Recur, (:cell, :init))
+@treelike Recur cell, init

 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

@@ -84,7 +84,7 @@ end
 RNNCell(in::Integer, out::Integer, σ = tanh;
         init = glorot_uniform) =
   RNNCell(σ, param(init(out, in)), param(init(out, out)),
-          param(zeros(out)), param(initn(out)))
+          param(zeros(out)), param(init(out)))

 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -94,7 +94,7 @@ end

 hidden(m::RNNCell) = m.h

-treelike(RNNCell)
+@treelike RNNCell

 function Base.show(io::IO, l::RNNCell)
   print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@@ -123,8 +123,8 @@ end
 function LSTMCell(in::Integer, out::Integer;
                   init = glorot_uniform)
   cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
-                  param(initn(out)), param(initn(out)))
-  cell.b.data[gate(out, 2)] = 1
+                  param(init(out)), param(init(out)))
+  cell.b.data[gate(out, 2)] .= 1
   return cell
 end

||||||
@ -143,7 +143,7 @@ end
|
|||||||
|
|
||||||
hidden(m::LSTMCell) = (m.h, m.c)
|
hidden(m::LSTMCell) = (m.h, m.c)
|
||||||
|
|
||||||
treelike(LSTMCell)
|
@treelike LSTMCell
|
||||||
|
|
||||||
Base.show(io::IO, l::LSTMCell) =
|
Base.show(io::IO, l::LSTMCell) =
|
||||||
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
|
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
|
||||||
@@ -170,7 +170,7 @@ end

 GRUCell(in, out; init = glorot_uniform) =
   GRUCell(param(init(out*3, in)), param(init(out*3, out)),
-          param(zeros(out*3)), param(initn(out)))
+          param(zeros(out*3)), param(init(out)))

 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
@@ -178,13 +178,13 @@ function (m::GRUCell)(h, x)
   r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
   z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
   h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
-  h′ = (1.-z).*h̃ .+ z.*h
+  h′ = (1 .- z).*h̃ .+ z.*h
   return h′, h′
 end

 hidden(m::GRUCell) = m.h

-treelike(GRUCell)
+@treelike GRUCell

 Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
@@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

 function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
+  -sum(y .* log.(ŷ) .* weight) / size(y, 2)
 end

 @deprecate logloss(x, y) crossentropy(x, y)
@@ -32,20 +32,21 @@ import Adapt.adapt

 adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

-@require CuArrays begin
-  import CuArrays: CuArray, cudaconvert
-  Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
+@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
+  import .CuArrays: CuArray, cudaconvert
+  import Base.Broadcast: BroadcastStyle, ArrayStyle
+  BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
   cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
 end

 function onehot(l, labels)
-  i = findfirst(labels, l)
+  i = something(findfirst(isequal(l), labels), 0)
   i > 0 || error("Value $l is not in labels")
   OneHotVector(i, length(labels))
 end

 function onehot(l, labels, unk)
-  i = findfirst(labels, l)
+  i = something(findfirst(isequal(l), labels), 0)
   i > 0 || return onehot(unk, labels)
   OneHotVector(i, length(labels))
 end
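On 0.7, `findfirst` takes a predicate rather than a value and returns `nothing` (not 0) on a miss, hence the `something(..., 0)` wrapper used twice above:

```julia
labels = [:a, :b, :c]
something(findfirst(isequal(:b), labels), 0)  # 2
something(findfirst(isequal(:z), labels), 0)  # 0 instead of nothing
```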
@@ -53,11 +54,15 @@ end
 onehotbatch(ls, labels, unk...) =
   OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

-argmax(y::AbstractVector, labels = 1:length(y)) =
-  labels[findfirst(y, maximum(y))]
+onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

-argmax(y::AbstractMatrix, l...) =
-  squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
+onecold(y::AbstractMatrix, labels...) =
+  dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)

+function argmax(xs...)
+  Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
+  return onecold(xs...)
+end

 # Ambiguity hack

@@ -2,14 +2,14 @@ module Optimise

 export train!,
   SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException

 struct Param{T}
   x::T
   Δ::T
 end

-Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
+Param(x::AbstractArray) = Param(x, zero(x))

 include("optimisers.jl")
 include("interface.jl")
@@ -17,6 +17,7 @@ include("train.jl")

 using Flux.Tracker: TrackedArray

-Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
+Param(x::TrackedArray) = Param(x.data, x.grad)
+# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

 end
@@ -14,7 +14,7 @@ function descentweightdecay(p::Param, η::Real, γ::Real)
 end

 function momentum(p::Param, ρ, η)
-  v = zeros(p.x)
+  v = zero(p.x)
   function ()
     @. v = ρ * v - η * p.Δ
     @. p.Δ = -v
@@ -23,7 +23,7 @@ end

 # Ref. https://arxiv.org/pdf/1212.0901.pdf
 function nesterov(p::Param, ρ, η)
-  v = zeros(p.x)
+  v = zero(p.x)
   function ()
     d = @. ρ^2 * v - (1+ρ) * η * p.Δ
     @. v = ρ*v - η*p.Δ
@@ -32,7 +32,7 @@ function nesterov(p::Param, ρ, η)
 end

 function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x)
+  acc = zero(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
     @. p.Δ *= η / √(acc + ϵ)
@@ -40,7 +40,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
 end

 function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
+  acc = zero(p.x) .+ ϵ
   function ()
     @. acc += p.Δ^2
     @. p.Δ *= η / √(acc + ϵ)
@@ -48,8 +48,8 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
 end

 function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x)
-  Δacc = zeros(p.x)
+  acc = zero(p.x)
+  Δacc = zero(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
     @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
@@ -58,8 +58,8 @@ function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
 end

 function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  vt = zeros(p.x)
+  mt = zero(p.x)
+  vt = zero(p.x)
   β1p, β2p = β1, β2
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
@@ -71,8 +71,8 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
 end

 function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  ut = zeros(p.x)
+  mt = zero(p.x)
+  ut = zero(p.x)
   β1p = β1
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
@@ -83,9 +83,9 @@ function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999,
 end

 function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  vt = zeros(p.x) .+ ϵ
-  v̂t = zeros(p.x) .+ ϵ
+  mt = zero(p.x)
+  vt = zero(p.x) .+ ϵ
+  v̂t = zero(p.x) .+ ϵ
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
     @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
@@ -95,8 +95,8 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
 end

 function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  vt = zeros(p.x)
+  mt = zero(p.x)
+  vt = zero(p.x)
   β1p, β2p = β1, β2
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
@@ -1,5 +1,6 @@
 using Juno
 using Flux.Tracker: back!
+import Base.depwarn

 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -14,6 +15,25 @@ macro interrupts(ex)
   end)
 end

+struct StopException <: Exception end
+"""
+    stop()
+
+Call `Flux.stop()` in a callback to indicate when a callback condition is met.
+This would trigger the train loop to stop and exit.
+
+```julia
+# Example callback:
+
+cb = function ()
+  accuracy() > 0.9 && Flux.stop()
+end
+```
+"""
+function stop()
+  throw(StopException())
+end
+
 """
     train!(loss, data, opt)

@@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
-    l = loss(d...)
-    @interrupts back!(l)
-    opt()
-    cb() == :stop && break
+    try
+      l = loss(d...)
+      @interrupts back!(l)
+      opt()
+      if cb() == :stop
+        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
+        break
+      end
+    catch ex
+      if ex isa StopException
+        break
+      else
+        rethrow(ex)
+      end
+    end
   end
 end

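Usage of the new early-stopping path, assuming a user-supplied `accuracy()` function (hypothetical) alongside the usual `loss`, `data`, and `opt` objects:

```julia
cb = () -> accuracy() > 0.9 && Flux.stop()
Flux.train!(loss, data, opt, cb = cb)  # the loop exits cleanly via StopException
```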
@@ -59,7 +90,7 @@ hello
 """
 macro epochs(n, ex)
   :(@progress for i = 1:$(esc(n))
-      info("Epoch $i")
+      @info "Epoch $i"
       $(esc(ex))
     end)
 end
@@ -12,7 +12,7 @@ tracker(x) = nothing
 istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
 grad(x) = grad(tracker(x))
-grad(::Void) = nothing
+grad(::Nothing) = nothing
 data(x) = x

 struct Call{F,As<:Tuple}
@@ -35,7 +35,7 @@ mutable struct Tracked{T}
   grad::T
   Tracked{T}(f::Call) where T = new(0, f, false)
   Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
-  Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
+  Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
 end

 istracked(x::Tracked) = true
@@ -46,14 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)

 function _forward end

-function track(f::F, xs...) where F
-  y, back = _forward(f, xs...)
-  ts = map(tracker, xs)
-  c = Call(back, ts)
-  track(c, y)
-end
-
-function track_kw(f::F, xs...; kw...) where F
+function track(f::F, xs...; kw...) where F
   y, back = _forward(f, xs...; kw...)
   track(Call(back, tracker.(xs)), y)
 end
@@ -84,10 +77,9 @@ include("numeric.jl")

 Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
 `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
-the sign of the gradient applied to `x`.
-"""
+the sign of the gradient applied to `x`."""
 hook(f, x) = istracked(x) ? track(hook, f, x) : x
-@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
+@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))

 """
     checkpoint(f, args...)
@@ -1,3 +1,9 @@
+import Base: *, ==
+
+import LinearAlgebra
+using Statistics
+using LinearAlgebra: Transpose, Adjoint, diagm, diag
+
 struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
   tracker::Tracked{A}
   data::A
@@ -21,24 +27,20 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
 TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
   TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)

-TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
+TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))

 Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}

 Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
   print(io, "TrackedArray{…,$A}")

-function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
-  if repr
-    print(io, "param(")
-    Base.showarray(io, data(X), true)
-    print(io, ")")
-  else
-    header && print(io, "Tracked ")
-    Base.showarray(io, data(X), false, header = header)
-  end
+function Base.summary(io::IO, x::TrackedArray)
+  print(io, "Tracked ")
+  summary(io, data(x))
 end

+Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
+
 Base.setindex!(xs::TrackedArray, v, i...) =
   error("Can't differentiate `setindex!`")

@@ -46,7 +48,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back

 # Fallthrough methods

-for f in :[Base.size, Base.ndims].args
+for f in :[Base.size, Base.ndims, Base.collect].args
   @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
 end

@@ -58,9 +60,9 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =

 Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)

-Base.:(==)(x::TrackedArray, y) = data(x) == y
-Base.:(==)(y, x::TrackedArray) = y == data(x)
-Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
+x::TrackedArray == y = data(x) == y
+y == x::TrackedArray = y == data(x)
+x::TrackedArray == y::TrackedArray = data(x) == data(y)

 # Array Stdlib

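The infix definitions above compile only because `==` is imported at the top of this file (the `import Base: *, ==` added in its first hunk); without the import, Julia would create a new local `==` instead of extending `Base`'s. A minimal sketch of the same pattern with a hypothetical wrapper type:

```julia
import Base: ==

struct Box
  value
end

b::Box == y = b.value == y  # extends Base.== with Box on the left
Box(3) == 3                 # true
```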
@@ -79,37 +81,20 @@ Base.:-(xs::TrackedArray) = track(-, xs)
 @grad -(xs) = -data(xs), Δ -> (-Δ,)

 Base.transpose(xs::TrackedArray) = track(transpose, xs)
-Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
+Base.adjoint(xs::TrackedArray) = track(adjoint, xs)

-@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),)
-@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
+@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
+@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)

-Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
-Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
+Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)

-@grad function repmat(xs, m, n = 1)
-  repmat(data(xs), m, n), function (Δ)
-    Δ′ = similar(xs)
-    S = size(xs)
-    for (i,v) in enumerate(data(Δ))
-      d1 = divrem(i-1, S[1]*m)
-      x = d1[2] % S[1]+1
-      y = d1[1] % S[2]+1
-      Δ′[x, y] += v
-    end
-    return (nobacksies(:repmat, Δ′), nothing, nothing)
-  end
-end
-
-Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
-
-@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
+@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
   repeat(data(xs), inner = inner, outer = outer), function (Δ)
     Δ′ = zero(xs)
     S = size(xs)

     # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
-    for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
+    for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
       # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
       # wrap around based on original size S.
       src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
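`pairs(IndexCartesian(), A)` replaces 0.6's `enumerate(IndexCartesian(), A)` and yields `CartesianIndex => value` pairs, which is what the loop above iterates over:

```julia
A = [10 20; 30 40]
for (idx, val) in pairs(IndexCartesian(), A)
  # idx isa CartesianIndex; e.g. idx == CartesianIndex(1, 2) when val == 20
end
```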
@ -119,8 +104,8 @@ Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
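The rewritten pullback maps every output index of `repeat` back to its source index with integer arithmetic. A standalone sketch of the same idea (the `repeat_pullback` name is ours, not Flux's), runnable without the tracker:

```julia
# Every output element of repeat(x; inner, outer) copies exactly one input
# element, so the adjoint accumulates each Δ entry onto its source index.
function repeat_pullback(Δ::AbstractArray, S::Tuple, inner, outer)
    Δ′ = zeros(eltype(Δ), S)
    for (dest_idx, val) in pairs(IndexCartesian(), Δ)
        # Undo the `inner` replication, then wrap around to undo `outer` tiling
        # (outer only affects Δ's size; the mod1 wrap handles it).
        src_idx = ntuple(d -> mod1(div(dest_idx[d] - 1, inner[d]) + 1, S[d]), length(S))
        Δ′[src_idx...] += val
    end
    return Δ′
end

x = [1.0 2.0; 3.0 4.0]
y = repeat(x, inner = (2, 1), outer = (1, 2))            # 4×4
# Each input element appears 2 × 2 = 4 times, so a gradient of ones sums to 4:
repeat_pullback(ones(size(y)), size(x), (2, 1), (1, 2))  # fill(4.0, 2, 2)
```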
 for f in [:vcat, :hcat]
+  UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
   @eval begin
     # This section is a bit of a hack since julia doesn't have a standardised
     # promotion mechanism for concatenation yet
@@ -129,18 +114,18 @@ for f in [:vcat, :hcat]
     # It should support tracked concatenation with rank ∈ (1,2) with a
     # TrackedArray anywhere among the arguments This works as long as base has
     # other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
-    Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
+    Base.$f(a::$UArray...) = track($f, a...)

     # It should support tracked concatenation with rank>2 if the TrackedArray is
     # first
     Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
-    Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
+    Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row

     # It should support tracked concatenation with rank>2 if the TrackedArray is
     # second
     Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
-    Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
-            c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
+    Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
+            c::$UArray...) =
       track($f, a, b, c...) # resolves ambiguity introduced by previous row
   end
 end
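The `UArray = :(...)` trick builds the union signature once as an expression and splices it into every generated method. A minimal, self-contained sketch of the same `@eval`-in-a-loop pattern (the `Wrapper` type is hypothetical, standing in for `TrackedArray`):

```julia
# The loop stamps out one forwarding method per symbol, exactly the
# @eval-in-a-loop pattern used in the hunk above.
struct Wrapper{T}; x::T; end

for f in [:length, :size, :ndims]
    @eval Base.$f(w::Wrapper, a...) = $f(w.x, a...)
end

w = Wrapper(rand(3, 4))
size(w), ndims(w), length(w)   # ((3, 4), 2, 12)
```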
@@ -175,21 +160,23 @@ end
   end
 end

-Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
-Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
+Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
+Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
+Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
+Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)

-@grad function cat(dims, Xs...)
-  cat(dims, data.(Xs)...), function (Δ)
-    start = ntuple(i -> 0, Val{ndims(Δ)})
+@grad function cat(Xs...; dims)
+  cat(data.(Xs)..., dims = dims), function (Δ)
+    start = ntuple(i -> 0, Val(ndims(Δ)))
     Δs = [begin
       dim_xs = 1:ndims(xs)
-      till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
-      xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
+      till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
+      xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
       d = reshape(Δ[xs_in_Δ...],size(xs))
       start = start .+ till_xs
       d
     end for xs in Xs]
-    return (nothing, Δs...,)
+    return (Δs...,)
   end
 end

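This hunk tracks two 0.7 language changes at once: `cat`'s dimension argument became the `dims` keyword, and `Val` is now constructed with parentheses. For reference:

```julia
a, b = [1 2], [3 4]

cat(a, b, dims = 1)       # 2×2, same as vcat(a, b); was cat(1, a, b) on 0.6
cat(a, b, dims = (1, 2))  # 2×4 block-diagonal, zero-filled off the blocks

ntuple(i -> 0, Val(2))    # (0, 0); Val{2} (the type) is no longer accepted here
```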
@@ -218,98 +205,86 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)

 # Reductions

-Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
-Base.sum(xs::TrackedArray) = track(sum, xs)
+Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
 Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))

-@grad sum(xs, dim...) = sum(data(xs), dim...),
-  Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
+@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
+  Δ -> (zero(xs) .+ Δ, )

 Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
 Base.prod(xs::TrackedArray) = track(prod, xs)
 Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))

 @grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
-@grad prod(xs, dim) = prod(data(xs), dim),
+@grad prod(xs, dim) = prod(data(xs), dims = dim),
   Δ -> (nobacksies(:sum,
           reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
         nothing)

 Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)

-Base.mean(xs::TrackedArray) = track(mean, xs)
-Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
+Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)

-Base.maximum(xs::TrackedArray) = track(maximum, xs)
-Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
-Base.minimum(xs::TrackedArray) = track(minimum, xs)
-Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
+Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
+Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)

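All the positional-region reductions collapse into one `dims = :` keyword method because Base adopted the same convention in 0.7:

```julia
using Statistics

x = [1.0 2.0; 3.0 4.0]

sum(x, dims = 1)    # 1×2 matrix [4.0 6.0]; was sum(x, 1) on 0.6
mean(x, dims = 2)   # 2×1 matrix [1.5; 3.5]
sum(x, dims = :)    # 10.0; `:` reduces over everything, same as sum(x)
```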
-LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
-LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
-LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
+import LinearAlgebra: dot
+
+dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
+dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
+dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)

 @grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)

 # Hacks to get std working
-Base.std(x::TrackedArray; mean = Base.mean(x)) =
-  sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
-Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
-  sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
+Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
+_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
+_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))

-Base.vecnorm(x::TrackedArray, p::Real = 2) =
+LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
   sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0

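`_std` is just the textbook corrected sample standard deviation, spelled with `sum` so every step goes through differentiable primitives. Checking the formula against `Statistics.std` directly:

```julia
using Statistics

x = randn(4, 3)
# Corrected sample std along dims = 1, written out by hand:
manual = sqrt.(sum((x .- mean(x, dims = 1)).^2, dims = 1) ./ (size(x, 1) - 1))
manual ≈ std(x, dims = 1)   # true
```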
-@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
-@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
+@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
+_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
+_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)

-@grad function maximum(xs, r...)
-  maximum(data(xs), r...), function (Δ)
+@grad function maximum(xs; dims = dims)
+  maximum(data(xs), dims = dims), function (Δ)
     Δ′ = zero(xs)
-    _, i = findmax(data(xs), r...)
+    _, i = findmax(data(xs), dims = dims)
     Δ′[i] = data(Δ)
-    return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
+    return (nobacksies(:maximum, Δ′),)
   end
 end
-@grad function minimum(xs, r...)
-  minimum(data(xs), r...), function (Δ)
+
+@grad function minimum(xs; dims = dims)
+  minimum(data(xs), dims = dims), function (Δ)
     Δ′ = zero(xs)
-    _, i = findmin(data(xs), r...)
+    _, i = findmin(data(xs), dims = dims)
     Δ′[i] = data(Δ)
-    return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
+    return (nobacksies(:minimum, Δ′),)
   end
 end

 # BLAS

-Base.diagm(x::TrackedVector) = track(diagm, x)
+LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
 @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)

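The updated pullbacks rely on `findmax`/`findmin` with `dims` returning `CartesianIndex` positions, which makes scattering the incoming gradient a one-liner. A Flux-free sketch:

```julia
x = [1.0 5.0; 4.0 2.0]
vals, idx = findmax(x, dims = 1)   # idx holds one CartesianIndex per column
Δ′ = zero(x)
Δ′[idx] .= 1.0                     # route gradient only to the maxima
Δ′                                 # [0.0 1.0; 1.0 0.0]
```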
-for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
-  @eval begin
-    import Base.$f
-    $f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
-    $f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
-    $f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
-
-    $f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
-    $f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
-    $f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
-
-    $f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
-    $f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
-    $f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
-  end
-end
+x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
+x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
+x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
+
+x::TrackedMatrix * y::AbstractVector = track(*, x, y)
+x::AbstractMatrix * y::TrackedVector = track(*, x, y)
+x::TrackedMatrix * y::TrackedVector = track(*, x, y)
+
+x::TrackedVector * y::AbstractVector = track(*, x, y)
+x::AbstractVector * y::TrackedVector = track(*, x, y)
+x::TrackedVector * y::TrackedVector = track(*, x, y)

 @grad a::AbstractMatrix * b::AbstractVecOrMat =
-  data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
-
-@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
-@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
-
-@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
-@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
+  data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)

 # NNlib

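The replacement uses Julia's infix method-definition syntax instead of the removed `A_mul_B*` family. Note that extending an operator this way requires it to be imported (the surrounding file presumably carries that import). A toy example with a hypothetical `Poly` type:

```julia
import Base: *   # an operator must be imported before infix extension

struct Poly; coeffs::Vector{Float64}; end

# Infix method definition, equivalent to Base.:*(a::Poly, b::Float64) = ...
a::Poly * b::Float64 = Poly(a.coeffs .* b)

(Poly([1.0, 2.0]) * 3.0).coeffs   # [3.0, 6.0]
```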
@@ -324,9 +299,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)

 @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)

-conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
-conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
-conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
+conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
+conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
+conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)

 @grad conv(x, w; kw...) =
   conv(data(x), data(w); kw...),
@@ -334,14 +309,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
     (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
      NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))

-maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
+maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)

 @grad function maxpool(x, k; kw...)
   y = maxpool(data(x), k; kw...)
   y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
 end

-meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
+meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)

 @grad function meanpool(x, k; kw...)
   y = meanpool(data(x), k; kw...)
@@ -352,45 +327,85 @@ end

 using ForwardDiff: Dual, partials, value

-dualify(xs, n) = xs
-dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
-dualify(xs::Real, ps) = Dual(xs, ps)
+trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

-unbroadcast(x::Tuple, Δ) =
-  x == size(Δ) ? Δ :
-  reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
+unbroadcast(x::AbstractArray, Δ) =
+  size(x) == size(Δ) ? Δ :
+  length(x) == length(Δ) ? trim(x, Δ) :
+    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

-unbroadcast(x::Tuple{}, Δ) = sum(Δ)
+unbroadcast(x::Number, Δ) = sum(Δ)
+unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
+unbroadcast(x::Base.RefValue{<:Val}, _) = nothing

-function getpartial(Δ, x, i)
-  @inbounds p = getindex(partials(x), i)
-  return Δ * p
+dual(x, p) = x
+dual(x::Real, p) = Dual(x, p)
+
+function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
+  dargs = ntuple(j -> dual(args[j], i==j), Val(N))
+  return Δ * f(dargs...).partials[1]
 end

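`partial` computes one directional derivative by seeding a single argument with a unit `Dual` perturbation. The same trick, runnable on its own with ForwardDiff:

```julia
using ForwardDiff: Dual

dual(x, p) = x
dual(x::Real, p) = Dual(x, p)

# ∂f/∂xᵢ at the given point: seed only argument i, read off the partial.
partial(f, Δ, i, args...) =
    Δ * f(ntuple(j -> dual(args[j], i == j), length(args))...).partials[1]

partial(*, 1.0, 1, 3.0, 5.0)   # ∂(x*y)/∂x at (3, 5) → 5.0
partial(*, 1.0, 2, 3.0, 5.0)   # ∂(x*y)/∂y at (3, 5) → 3.0
```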
-function ∇broadcast(f, args::Vararg{Any,N}) where N
-  sizes = size.(args)
-  dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
-  out = broadcast(f, dargs...)
-  eltype(out) <: Dual || return out
-  y = value.(out)
-  back = function (Δ_)
-    Δ = data(Δ_)
-    Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
-    dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
-    nobacksies(:broadcast, dxs)
+@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
+  y = broadcast(f, data.(args)...)
+  eltype(y) <: Real || return y
+  eltype(y) == Bool && return y
+  function back(Δ)
+    Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
+    dxs = unbroadcast.(args, Δargs)
+    return nobacksies(:broadcast, dxs)
   end
   # So we can return non-tracked arrays
   track(Call(back, tracker.(args)), y)
 end

-Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
-Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
-Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
-Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
-Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
-Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
-Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
-Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
-Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
-
-Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
+using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
+
+struct TrackedStyle <: BroadcastStyle end
+
+Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
+Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
+
+# We have to re-build the original broadcast struct to get the appropriate array
+# style. We need this primarily to support CuArrays' broadcasting fixes.
+broadcast_rebuild(xs) = data(xs)
+
+broadcast_rebuild(bc::Broadcasted) =
+  broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
+
+preprocess(x) = x
+
+function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
+  bc1 = Broadcast.flatten(bc)
+  bc2 = Broadcast.flatten(broadcast_rebuild(bc))
+  ∇broadcast(bc2.f, bc1.args...)
+end
+
+using Requires
+
+# https://github.com/FluxML/Flux.jl/issues/353
+@init Requires.isprecompiling() || @eval Base.Broadcast begin
+  function flatten(bc::Broadcasted{Style}) where {Style}
+    isflat(bc) && return bc
+    args = cat_nested(bc)
+    let makeargs = make_makeargs(bc), f = bc.f
+      newf = @inline function(args::Vararg{Any,N}) where N
+        f(makeargs(args...)...)
+      end
+      return Broadcasted{Style}(newf, args, bc.axes)
+    end
+  end
+  @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
+    bc = t[1]
+    let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
+      let makeargs = make_makeargs(makeargs, bc.args)
+        headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
+        return @inline function(args::Vararg{Any,N}) where N
+          args1 = makeargs(args...)
+          a, b = headargs(args1...), tailargs(args1...)
+          (f(a...), b...)
+        end
+      end
+    end
+  end
+end
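This is the Julia 0.7 broadcasting interface: a custom `BroadcastStyle` that out-ranks every other style funnels each fused dot-expression into one `materialize` method, which is where Flux calls `∇broadcast`. A generic sketch of the same mechanism with made-up names (`Logged`, `MyStyle`, `rebuild`):

```julia
using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted

struct Logged{T}; x::T; end
struct MyStyle <: BroadcastStyle end

Base.Broadcast.broadcastable(l::Logged) = l                       # keep our wrapper
Base.Broadcast.BroadcastStyle(::Type{<:Logged}) = MyStyle()
Base.Broadcast.BroadcastStyle(::MyStyle, ::BroadcastStyle) = MyStyle()

# Rebuild the lazy Broadcasted tree with plain arrays, like broadcast_rebuild.
rebuild(bc::Broadcasted) = broadcasted(bc.f, rebuild.(bc.args)...)
rebuild(l::Logged) = l.x
rebuild(x) = x

Base.Broadcast.materialize(bc::Broadcasted{MyStyle}) =
    Broadcast.materialize(rebuild(bc))

Logged([1, 2]) .+ 1 .* 2   # [3, 4]: our materialize saw the whole fused tree
```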
@@ -26,7 +26,7 @@ function back_(c::Call, Δ)
   foreach(back, c.args, data.(Δs))
 end

-back_(::Call{Void}, Δ) = nothing
+back_(::Call{Nothing}, Δ) = nothing

 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
@@ -47,7 +47,7 @@ function back(x::Tracked, Δ)
   return
 end

-back(::Void, _) = return
+back(::Nothing, _) = return

 # Interface methods

@@ -70,7 +70,7 @@ struct Params
   Params(xs) = new(IdSet(xs))
 end

-@forward Params.params Base.start, Base.next, Base.done
+@forward Params.params Base.iterate, Base.length

 function Base.show(io::IO, ps::Params)
   print(io, "Params([")
@@ -79,14 +79,16 @@ function Base.show(io::IO, ps::Params)
 end

 struct Grads
-  grads::ObjectIdDict
+  grads::IdDict{Any,Any}
 end

 Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")

-Grads() = Grads(ObjectIdDict())
+Grads() = Grads(IdDict())

-Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
+@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
+
+Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))

 Base.getindex(g::Grads, x::Tracked) = g.grads[x]
 function Base.getindex(g::Grads, x)
@@ -94,9 +96,8 @@ function Base.getindex(g::Grads, x)
   g[tracker(x)]
 end

-@forward Grads.grads Base.setindex!, Base.haskey
-
-accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
+accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

 function back_(g::Grads, c::Call, Δ)
   Δs = c.func(Δ)
@@ -105,7 +106,7 @@ function back_(g::Grads, c::Call, Δ)
   foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
 end

-back_(g::Grads, ::Call{Void}, Δ) = nothing
+back_(g::Grads, ::Call{Nothing}, Δ) = nothing

 function back(g::Grads, x::Tracked, Δ)
   x.isleaf && (accum!(g, x, Δ); return)
@@ -119,7 +120,7 @@ function back(g::Grads, x::Tracked, Δ)
   return
 end

-back(::Grads, ::Void, _) = return
+back(::Grads, ::Nothing, _) = return

 function forward(f, ps::Params)
   y = f()
@@ -136,7 +137,7 @@ end
 function forward(f, args...)
   args = param.(args)
   y, back = forward(() -> f(args...), Params(args))
-  y, Δ -> getindex.(back(Δ), args)
+  y, Δ -> getindex.(Ref(back(Δ)), args)
 end

 function losscheck(x)
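`ObjectIdDict` became `IdDict` in 0.7. Both hash by object identity (`===`) rather than value equality, which is what lets `Grads` key gradients on the exact parameter objects:

```julia
d = IdDict{Any,Any}()
a, b = [1, 2], [1, 2]        # equal values, distinct objects
d[a] = "grad of a"
haskey(d, a), haskey(d, b)   # (true, false); a value-keyed Dict gives (true, true)
```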
@@ -1,17 +1,17 @@
 struct IdSet{T} <: AbstractSet{T}
-  dict::ObjectIdDict
-  IdSet{T}() where T = new(ObjectIdDict())
+  dict::IdDict{T,Nothing}
+  IdSet{T}() where T = new(IdDict{T,Nothing}())
 end

-Base.eltype{T}(::IdSet{T}) = T
+Base.eltype(::IdSet{T}) where T = T

 IdSet() = IdSet{Any}()

-Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
-Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
+Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
+Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
 Base.in(x, s::IdSet) = haskey(s.dict, x)

-(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...)
+IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)

 IdSet(xs) = IdSet{eltype(xs)}(xs)

@@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()

 @forward IdSet.dict Base.length

-Base.start(s::IdSet) = start(keys(s.dict))
-Base.next(s::IdSet, st) = next(keys(s.dict), st)
-Base.done(s::IdSet, st) = done(keys(s.dict), st)
+function Base.iterate(v::IdSet, state...)
+  y = Base.iterate(keys(v.dict), state...)
+  y === nothing && return nothing
+  return (y[1], y[2])
+end
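The `start`/`next`/`done` protocol was replaced in 0.7 by a single `iterate` function that returns either `nothing` or an `(element, state)` tuple, so one method now suffices. A minimal custom iterator for comparison (the `Evens` type is ours):

```julia
struct Evens; n::Int; end

function Base.iterate(e::Evens, i = 1)
    i > e.n && return nothing
    return (2i, i + 1)    # (element, next state)
end
Base.length(e::Evens) = e.n
Base.eltype(::Type{Evens}) = Int

collect(Evens(4))   # [2, 4, 6, 8]
```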
@@ -1,5 +1,5 @@
 function ngradient(f, xs::AbstractArray...)
-  grads = zeros.(xs)
+  grads = zero.(xs)
   for (x, Δ) in zip(xs, grads), i in 1:length(x)
     δ = sqrt(eps())
     tmp = x[i]
@@ -115,3 +115,7 @@ end
 function back_(c::Call{typeof(collect)}, Δ)
   foreach(back, c.args[1], data(Δ))
 end
+
+function back_(g::Grads, c::Call{typeof(collect)}, Δ)
+  foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
+end
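`ngradient` is a central-difference gradient check: perturb one entry at a time and difference the loss. A self-contained version of the same routine, for reference:

```julia
function ngradient(f, xs::AbstractArray...)
    grads = zero.(xs)
    for (x, Δ) in zip(xs, grads), i in 1:length(x)
        δ = sqrt(eps())
        tmp = x[i]
        x[i] = tmp - δ/2; y1 = f(xs...)   # f at x[i] - δ/2
        x[i] = tmp + δ/2; y2 = f(xs...)   # f at x[i] + δ/2
        x[i] = tmp                        # restore the entry
        Δ[i] = (y2 - y1) / δ
    end
    return grads
end

ngradient(x -> sum(x .^ 2), [1.0, 2.0])[1]   # ≈ [2.0, 4.0]
```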
@@ -7,16 +7,27 @@ mapchildren(f, x) = x
 children(x::Tuple) = x
 mapchildren(f, x::Tuple) = map(f, x)

-function treelike(T, fs = fieldnames(T))
-  @eval current_module() begin
+function treelike(m::Module, T, fs = fieldnames(T))
+  @eval m begin
     Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
     Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
   end
 end

+function treelike(T, fs = fieldnames(T))
+  Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
+  treelike(Base._current_module(), T, fs)
+end
+
+macro treelike(T, fs = nothing)
+  fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
+  fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
+  :(treelike(@__MODULE__, $(esc(T)), $(fs...)))
+end
+
 isleaf(x) = isempty(children(x))

-function mapleaves(f, x; cache = ObjectIdDict())
+function mapleaves(f, x; cache = IdDict())
   haskey(cache, x) && return cache[x]
   cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
 end
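For a concrete struct, the methods `treelike`/`@treelike` generate look like the hand-written sketch below (the generic `children`/`mapchildren` stubs and the `Affine` type are ours, standing in for Flux's):

```julia
# Stand-ins for Flux's generic functions; the two methods at the bottom are
# roughly what @treelike Affine would generate.
children(x) = ()
mapchildren(f, x) = x

struct Affine; W; b; end

children(a::Affine) = (a.W, a.b)
mapchildren(f, a::Affine) = Affine(f(a.W), f(a.b))

a = Affine(rand(2, 2), rand(2))
mapchildren(x -> 2 .* x, a).b == 2 .* a.b   # true: f is applied per field
```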
@@ -53,7 +64,7 @@ cpu(m) = mapleaves(x -> adapt(Array, x), m)

 gpu_adaptor = identity

-@require CuArrays begin
+@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
   global gpu_adaptor = CuArrays.cu
 end

@@ -1,8 +1,8 @@
 # Arrays

 initn(dims...) = randn(dims...)/100
-glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
-glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
+glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
+glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))

 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

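The extra dots are not cosmetic: Julia 1.0 removed automatic array ± scalar arithmetic, so the subtraction in `glorot_uniform` needs an explicit broadcast:

```julia
rand(2, 2) .- 0.5              # fine on 0.7/1.0
# rand(2, 2) - 0.5             # MethodError on Julia ≥ 1.0
randn(2, 2) * sqrt(2.0 / 4)    # array–scalar * is still linear algebra; dot optional
```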
@@ -119,7 +119,7 @@ function throttle(f, timeout; leading=true, trailing=false)
   end

   cooldown = false
-  @schedule try
+  @async try
     while (sleep(timeout); later != nothing)
       later()
       later = nothing
@@ -145,7 +145,7 @@ function jacobian(m,x)
   y = m(xp)
   k = length(y)
   n = length(x)
-  J = Matrix{eltype(x)}(n,k)
+  J = Matrix{eltype(x)}(undef,n,k)
   for i = 1:k
     Flux.back!(y[i]) # Populate gradient accumulator
     J[:,i] = xp.grad
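Uninitialised array constructors gained a mandatory `undef` argument in 0.7, hence the `jacobian` change:

```julia
J = Matrix{Float64}(undef, 3, 2)   # was Matrix{Float64}(3, 2) on 0.6
fill!(J, 0.0)                      # contents are arbitrary until written
```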
@@ -1,7 +1,7 @@
-using Flux, Flux.Tracker, CuArrays, Base.Test
+using Flux, Flux.Tracker, CuArrays, Test
 using Flux: gpu

-info("Testing Flux/GPU")
+@info "Testing Flux/GPU"

 @testset "CuArrays" begin

@@ -14,6 +14,7 @@ cx = gpu(x)
 x = Flux.onehotbatch([1, 2, 3], 1:3)
 cx = gpu(x)
 @test cx isa Flux.OneHotMatrix && cx.data isa CuArray
+@test (cx .+ 1) isa CuArray

 m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
 cm = gpu(m)
@@ -25,10 +26,13 @@ x = [1,2,3]
 cx = gpu(x)
 @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)

-# Fails in Pkg.test ffs
-# c = gpu(Conv((2,2),3=>4))
-# l = c(gpu(rand(10,10,3,2)))
-# Flux.back!(sum(l))
+xs = param(rand(5,5))
+ys = Flux.onehotbatch(1:5,1:5)
+@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
+
+c = gpu(Conv((2,2),3=>4))
+l = c(gpu(rand(10,10,3,2)))
+Flux.back!(sum(l))

 end

@@ -1,6 +1,6 @@
-using Flux, CuArrays, Base.Test
+using Flux, CuArrays, Test

-info("Testing Flux/CUDNN")
+@info "Testing Flux/CUDNN"

 @testset "RNN" begin
   @testset for R in [RNN, GRU, LSTM]
@@ -1,8 +1,11 @@
 using Flux.Data
-using Base.Test
+using Test

 @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args

 @test length(CMUDict.phones()) == 39

 @test length(CMUDict.symbols()) == 84
+
+@test MNIST.images()[1] isa Matrix
+@test MNIST.labels() isa Vector{Int64}
@@ -4,7 +4,7 @@ using Flux: testmode!
 x = [1.,2.,3.]
 @test x == testmode!(Dropout(0.1))(x)
 @test x == Dropout(0)(x)
-@test zeros(x) == Dropout(1)(x)
+@test zero(x) == Dropout(1)(x)

 x = rand(100)
 m = Dropout(0.9)
@@ -53,17 +53,17 @@ end
   # .1 * 4 + 0 = .4
   @test m.μ ≈ reshape([0.3, 0.4], 2, 1)

-  # julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+  # julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
   # 2×1 Array{Float64,2}:
   #  1.14495
   #  1.14495
-  @test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+  @test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

   testmode!(m)
   @test !m.active

   x′ = m(x).data
-  @test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
+  @test x′[1] ≈ (1 .- 0.3) / 1.1449489742783179
 end

 # with activation function
@@ -1,4 +1,4 @@
-using Base.Test
+using Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
             σ, binarycrossentropy, logitbinarycrossentropy

@@ -42,8 +42,8 @@ const ϵ = 1e-7

 logŷ, y = randn(3), rand(3)
 @testset "binarycrossentropy" begin
-  @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
-  @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
+  @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))
+  @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))
 end

 @testset "logitbinarycrossentropy" begin
test/onehot.jl (new file, +13)
@@ -0,0 +1,13 @@
+using Flux:onecold
+using Test
+
+@testset "onecold" begin
+  a = [1, 2, 5, 3.]
+  A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
+  labels = ['A', 'B', 'C', 'D']
+  @test onecold(a) == 3
+  @test onecold(A) == [3, 1, 4]
+  @test onecold(a, labels) == 'C'
+  @test onecold(A, labels) == ['C', 'A', 'D']
+end
@@ -1,6 +1,6 @@
 using Flux.Optimise
 using Flux.Tracker
+using Test
 @testset "Optimise" begin
   w = randn(10, 10)
   @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
@@ -23,7 +23,7 @@ end
   Flux.train!(() -> (sleep(0.1); i += 1; l),
               Iterators.repeated((), 100),
               ()->(),
-              cb = Flux.throttle(() -> (i > 3 && :stop), 1))
+              cb = Flux.throttle(() -> (i > 3 && stop()), 1))

   @test 3 < i < 50
 end
@@ -1,17 +1,37 @@
-using Flux, Base.Test
-
-srand(0)
+# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
+# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
+if Base.JLOptions().check_bounds == 1
+  file = @__FILE__
+  run(```
+    $(Base.julia_cmd())
+    --color=$(Base.have_color ? "yes" : "no")
+    --compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
+    --startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
+    --code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
+    $(file)
+  ```)
+  exit()
+end
+
+using Flux, Test, Random
+using Random
+
+Random.seed!(0)
+
+# So we can use the system CuArrays
+insert!(LOAD_PATH, 2, "@v#.#")

 @testset "Flux" begin

 include("utils.jl")
+include("onehot.jl")
 include("tracker.jl")
 include("layers/normalisation.jl")
 include("layers/stateless.jl")
 include("optimise.jl")
 include("data.jl")

-if Base.find_in_path("CuArrays") ≠ nothing
+if Base.find_package("CuArrays") != nothing
   include("cuda/cuda.jl")
 end

@@ -1,22 +1,27 @@
-using Flux.Tracker, Base.Test, NNlib
+using Flux
+using Flux.Tracker, Test, NNlib
 using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
 using NNlib: conv
+using Printf: @sprintf
+using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
+using Statistics: mean, std
+using Random
+# using StatsBase

 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
-gradtest(f, dims...) = gradtest(f, rand.(dims)...)
+gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)

 @testset "Tracker" begin

 @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
 @test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

-@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
-@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
-@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
-@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
+@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
+@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
+@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
+@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
+@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
+@test gradtest(x -> sum(x), randn(Float64,2,3))
+@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
 @test gradtest(x -> prod(x), (3,4,5))

 @test gradtest(x -> softmax(x).*(1:3), 3)
|
|||||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
function promotiontest(f, A, B, C)
|
function promotiontest(f, A, B, C)
|
||||||
r0 = f(A, B, C)
|
r0 = f(A, B, C)
|
||||||
r1 = f(param(A), B, C)
|
r1 = f(param(A), B, C)
|
||||||
@ -48,8 +52,8 @@ function promotiontest(f, A, B, C)
|
|||||||
end
|
end
|
||||||
|
|
||||||
@testset "concat" begin
|
@testset "concat" begin
|
||||||
cat1(x...) = cat(1, x...)
|
cat1(x...) = cat(x..., dims = 1)
|
||||||
cat2(x...) = cat(2, x...)
|
cat2(x...) = cat(x..., dims = 2)
|
||||||
|
|
||||||
@testset for vcatf in [vcat, cat1]
|
@testset for vcatf in [vcat, cat1]
|
||||||
@test gradtest(vcatf, rand(5), rand(3))
|
@test gradtest(vcatf, rand(5), rand(3))
|
||||||
@ -61,6 +65,7 @@ end
|
|||||||
@test gradtest(vcatf, rand(5)', rand(2,5))
|
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@testset for hcatf in [hcat, cat2]
|
@testset for hcatf in [hcat, cat2]
|
||||||
@test gradtest(hcatf, rand(5), rand(5))
|
@test gradtest(hcatf, rand(5), rand(5))
|
||||||
@test gradtest(hcatf, rand(5)', rand(5)')
|
@test gradtest(hcatf, rand(5)', rand(5)')
|
||||||
@ -71,17 +76,17 @@ end
|
|||||||
@test gradtest(hcatf, rand(5), rand(5,2))
|
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||||
@test gradtest(catf, rand(5))
|
@test gradtest(catf, rand(5))
|
||||||
@test gradtest(catf, rand(5)')
|
@test gradtest(catf, rand(5)')
|
||||||
@test gradtest(catf, rand(2,5))
|
@test gradtest(catf, rand(2,5))
|
||||||
@test gradtest(catf, rand(2,5,3))
|
@test gradtest(catf, rand(2,5,3))
|
||||||
end
|
end
|
||||||
|
|
||||||
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||||
|
|
||||||
@testset "cat($dim, ...)" for dim in 3:5
|
@testset "cat($dim, ...)" for dim in 3:5
|
||||||
catdim = (x...) -> cat(dim, x...)
|
catdim = (x...) -> cat(x..., dims = dim)
|
||||||
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
||||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||||
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
||||||
@@ -89,12 +94,12 @@ end

   @test !isa(vcat(rand(2)), TrackedArray)
   @test !isa(hcat(rand(2)), TrackedArray)
-  @test !isa(cat(1,rand(2)), TrackedArray)
+  @test !isa(cat(rand(2), dims=1), TrackedArray)

-  @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
+  @test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))

   @testset "promotiontest" begin
-    @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
+    @testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
       promotiontest(fcat, rand(2), rand(2), rand(2))
       promotiontest(fcat, rand(2)', rand(2)', rand(2)')
       promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
@@ -105,16 +110,14 @@ end
     promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
     promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
     promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
-    promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
+    promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
   end
 end

 @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

-# TODO unreliable
-@test gradtest(x -> repmat(x, 5,5), rand(4,5))
-@test gradtest(x -> repmat(x, 5), rand(4,5))
+@test gradtest(x -> repeat(x; inner=2), rand(5))

 @test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
 @test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

@@ -124,54 +127,54 @@ end
 @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
 @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

-@test gradtest(diagm, rand(3))
+@test gradtest(f-> Matrix(Diagonal(f)), rand(3))

 @testset "mean" begin
   @test gradtest(mean, rand(2, 3))

-  @test gradtest(x -> mean(x, 1), rand(2, 3))
-  @test gradtest(x -> mean(x, 2), rand(2, 3))
-  @test gradtest(x -> mean(x, 3), rand(2, 3, 4))
+  @test gradtest(x -> mean(x, dims=1), rand(2, 3))
+  @test gradtest(x -> mean(x, dims=2), rand(2, 3))
+  @test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))

-  @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
+  @test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
 end

 @testset "maximum" begin
   @test gradtest(maximum, rand(2, 3))

-  @test gradtest(x -> maximum(x, 1), rand(2, 3))
-  @test gradtest(x -> maximum(x, 2), rand(2, 3))
-  @test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
+  @test gradtest(x -> maximum(x, dims=1), rand(2, 3))
+  @test gradtest(x -> maximum(x, dims=2), rand(2, 3))
+  @test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))

-  @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
+  @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
 end

 @testset "minimum" begin
   @test gradtest(minimum, rand(2, 3))

-  @test gradtest(x -> minimum(x, 1), rand(2, 3))
-  @test gradtest(x -> minimum(x, 2), rand(2, 3))
-  @test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
+  @test gradtest(x -> minimum(x, dims=1), rand(2, 3))
+  @test gradtest(x -> minimum(x, dims=2), rand(2, 3))
+  @test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))

-  @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
+  @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
 end

 @test gradtest(x -> std(x), rand(5,5))
-@test gradtest(x -> std(x, 1), rand(5,5))
+@test gradtest(x -> std(x, dims = 1), rand(5,5))

 @test gradtest((x, y) -> x .* y, rand(5), rand(5))
 @test gradtest(dot, rand(5), rand(5))

-@test gradtest(vecnorm, rand(5))
+@test gradtest(norm, rand(5))

 @test gradtest(rand(5)) do x
   y = x.^2
   2y + x
 end

-@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
-@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
-@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
+@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
+@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
+@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))

 @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
@@ -211,14 +214,11 @@ end
 @testset "Fallbacks" begin
   xs = param([1 2; 3 4])
   @test similar(xs) isa Matrix{Float64}
-  # Remove this test if we do LowerTriangular properly
-  L = LowerTriangular(xs)
-  @test L*L' isa Matrix{TrackedReal{Float64}}
 end

 @test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

-@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
+@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))

 b = param(rand())
 Tracker.back!(b)
@@ -231,6 +231,11 @@ Tracker.back!(b)
   z = xy[1]*xy[2]
   back!(z)
   @test grad.((x,y)) == (3, 2)
+
+  @test Tracker.gradient(2, 3) do x, y
+    xy = Tracker.collect([x, y])
+    xy[1]*xy[2]
+  end == (3, 2)
 end

 # Gradient Hooks
@@ -1,9 +1,13 @@
-using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
+using Flux
+using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
+using StatsBase: std
+using Random
+using Test

 @testset "Throttle" begin
   @testset "default behaviour" begin
     a = []
-    f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
+    f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
     f()
     f()
     f()
@@ -13,7 +17,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian

   @testset "leading behaviour" begin
     a = []
-    f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
+    f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
     f()
     @test length(a) == 1
     f()
@@ -25,7 +29,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian

   @testset "trailing behaviour" begin
     a = []
-    f = throttle(()->push!(a, now()), 1, leading=false, trailing=true)
+    f = throttle(()->push!(a, time()), 1, leading=false, trailing=true)
     f()
     @test length(a) == 0
     f()
@@ -59,7 +63,7 @@ end

 @testset "Initialization" begin
   # Set random seed so that these tests don't fail randomly
-  srand(0)
+  Random.seed!(0)
   # initn() should yield a kernel with stddev ~= 1e-2
   v = initn(10, 10)
   @test std(v) > 0.9*1e-2