Merge branch 'master' into cudnn_batchnorm

commit 3b448ce1ac
@@ -4,11 +4,11 @@ os:
- linux
# - osx
julia:
- 0.6
- 0.7
# uncomment the following lines to override the default test script
script:
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
# script:
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
# - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
after_success:
- julia -e 'Pkg.add("Documenter")'
- julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'

REQUIRE
@@ -1,4 +1,4 @@
julia 0.6.0
julia 0.7-
Juno
MacroTools 0.3.3
NNlib

@@ -9,6 +9,7 @@ Colors
ZipFile
AbstractTrees
Reexport
StatsBase

# AD
ForwardDiff 0.5.0
@@ -134,7 +134,7 @@ All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around

```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```

The `Tracker` stores the gradient of a given object, which we've seen before.

@@ -211,7 +211,7 @@ m(5) # => 26
Flux provides a set of helpers for custom layers, which you can enable by calling

```julia
Flux.treelike(Affine)
Flux.@treelike Affine
```

This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
@@ -4,7 +4,7 @@ module Flux

# Zero Flux Given

using Juno, Requires, Reexport
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, RNN, LSTM, GRU, Conv,

@@ -37,6 +37,6 @@ include("layers/normalise.jl")

include("data/Data.jl")

@require CuArrays include("cuda/cuda.jl")
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")

end # module
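Most of the `@require` changes in this diff follow the same pattern: on 0.7, Requires.jl hooks are registered at module load time via `@init`, and the optional dependency must be identified by name *and* UUID. A minimal sketch of the pattern, inside a package module:

```julia
using Requires

# `@init` defers the hook to module load time; the body runs only once the
# user actually loads CuArrays, which is now pinned by its UUID.
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
    include("cuda/cuda.jl")
end
```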
@@ -3,23 +3,23 @@ using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
import ..Flux: data

mutable struct DropoutDesc
ptr::Ptr{Void}
ptr::Ptr{Nothing}
states::CuVector{UInt8}
end

Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr

function DropoutDesc(ρ::Real; seed::Integer=0)
d = [C_NULL]
s = Csize_t[0]
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
desc,libcudnn_handle[],ρ,states,length(states),seed)
finalizer(desc, x ->
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x))
return desc
end
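Nearly all of the churn above is one rename: Julia 0.7 calls the old `Void` type `Nothing`, so C void pointers in `ccall` signatures become `Ptr{Nothing}` (0.7 also introduces the alias `Cvoid` for this). For instance:

```julia
# Julia 0.6: Void / Ptr{Void}   →   Julia 0.7: Nothing / Ptr{Nothing}
@assert nothing isa Nothing
@assert Ptr{Nothing} === Ptr{Cvoid}   # Cvoid is an alias for Nothing
```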
@@ -24,25 +24,25 @@ end

function phones()
load()
Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")),
"\n", keep = false), "\t")))
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
"\n", keepempty = false), "\t")))
end

function symbols()
load()
Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
"\n", keep = false))
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
"\n", keepempty = false))
end

function rawdict()
load()
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n"))))
end

validword(s) = ismatch(r"^[\w\-\.]+$", s)
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)

cmudict() = filter((s, ps) -> validword(s), rawdict())
cmudict() = filter(p -> validword(p.first), rawdict())

alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
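Three 0.7 renames drive this hunk: `readstring(f)` becomes `read(f, String)`, the `keep` keyword of `split` becomes `keepempty`, and `ismatch(r, s)` becomes `occursin(r, s)`; additionally, `filter` over a `Dict` now takes a single `Pair` argument rather than `(k, v)`. For example:

```julia
s = "a\tb\n\nc\td\n"
lines = split(s, "\n", keepempty = false)   # 0.6: split(s, "\n", keep = false)
@assert lines == ["a\tb", "c\td"]
@assert occursin(r"^[\w\-\.]+$", "A-B.C")   # 0.6: ismatch(r"^[\w\-\.]+$", "A-B.C")
```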
@@ -16,7 +16,7 @@ m(x) == m[2](m[1](x))
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
"""
type Chain
struct Chain
layers::Vector{Any}
Chain(xs...) = new([xs...])
end

@@ -28,7 +28,7 @@ children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)

(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
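The `Chain` forward pass threads the input through the layers with `foldl`; on 0.7 the starting value moves from a positional argument to the `init` keyword. A standalone sketch with plain functions standing in for layers:

```julia
layers = [x -> 2x, x -> x + 1]
# 0.6: foldl((x, m) -> m(x), 5, layers)
y = foldl((x, m) -> m(x), layers; init = 5)
@assert y == 11   # (5 * 2) + 1
```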
@@ -38,9 +38,6 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end

# Seem to need this for `accumulate`; try removing on 0.7
Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any

activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)

"""

@@ -76,7 +73,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
return Dense(param(initW(out, in)), param(initb(out)), σ)
end

treelike(Dense)
@treelike Dense

function (a::Dense)(x)
W, b, σ = a.W, a.b, a.σ

@@ -107,7 +104,7 @@ end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(param(initα(in)), param(initβ(in)))

treelike(Diagonal)
@treelike Diagonal

function (a::Diagonal)(x)
α, β = a.α, a.β
@@ -32,10 +32,10 @@ Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;

Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
Conv(param(init(k..., ch...)), param(zero(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

Flux.treelike(Conv)
@treelike Conv

function (c::Conv)(x)
# TODO: breaks gpu broadcast :(
@@ -57,7 +57,7 @@ end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))

treelike(LayerNorm)
@treelike LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x))

@@ -112,10 +112,10 @@ end

# NOTE: Keeping the ϵ smaller than 1e-5 is not supported by CUDNN
function BatchNorm(chs::Integer, λ = identity;
initβ = x->zeros(Float32,x),
initγ = x->ones(Float32,x),
initβ = (i) -> zeros(i),
initγ = (i) -> ones(i),
ϵ = 1f-5,
momentum = 0.1f0)
momentum = 0.1)
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(Float32, chs), ones(Float32, chs), ϵ, momentum, true)
end

@@ -140,10 +140,9 @@ function (BN::BatchNorm)(x)
σ² = sum((x.-μ).^2, axes) ./ m

# update moving mean/std
mtm = convert(T, BN.momentum)

BN.μ = ((1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))) |> data
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* squeeze(data(σ²), (axes...))*m/(m-1)) |> data
mtm = data(convert(T, BN.momentum))
BN.μ = ((1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...)))
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* squeeze(data(σ²), (axes...)) .* m ./ (m - 1))
end

ϵ = convert(T, BN.ϵ)
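For reference, the moving statistics above are the usual exponential update, with Bessel's correction applied to the batch variance; a standalone sketch (the `mu_run`/`var_run` names are illustrative, not Flux API):

```julia
mtm, m = 0.1, 32                      # momentum and batch size
mu_run, var_run = zeros(4), ones(4)   # stored per-channel statistics
mu_b, var_b = randn(4), rand(4)       # statistics of the current batch

mu_run  = (1 - mtm) .* mu_run  .+ mtm .* mu_b
var_run = (1 - mtm) .* var_run .+ mtm .* var_b .* m ./ (m - 1)  # Bessel-corrected
```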
@@ -38,7 +38,7 @@ function (m::Recur)(xs...)
return y
end

treelike(Recur, (:cell, :init))
@treelike Recur cell, init

Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

@@ -94,7 +94,7 @@ end

hidden(m::RNNCell) = m.h

treelike(RNNCell)
@treelike RNNCell

function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))

@@ -122,7 +122,7 @@ end

function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)),
param(initn(out)), param(initn(out)))
cell.b.data[gate(out, 2)] = 1
return cell

@@ -143,7 +143,7 @@ end

hidden(m::LSTMCell) = (m.h, m.c)

treelike(LSTMCell)
@treelike LSTMCell

Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")

@@ -170,7 +170,7 @@ end

GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zeros(out*3)), param(initn(out)))
param(zero(out*3)), param(initn(out)))

function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1)

@@ -178,13 +178,13 @@ function (m::GRUCell)(h, x)
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h′ = (1.-z).*h̃ .+ z.*h
h′ = (1 .- z).*h̃ .+ z.*h
return h′, h′
end

hidden(m::GRUCell) = m.h

treelike(GRUCell)
@treelike GRUCell

Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
@@ -32,20 +32,20 @@ import Adapt.adapt

adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

@require CuArrays begin
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
import CuArrays: CuArray, cudaconvert
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end

function onehot(l, labels)
i = findfirst(labels, l)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
end

function onehot(l, labels, unk)
i = findfirst(labels, l)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
end
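0.7's `findfirst` takes a predicate first and returns `nothing` (not `0`) on a miss, hence the `something(..., 0)` wrapper above:

```julia
labels = [:cat, :dog, :bird]
# 0.6: findfirst(labels, :dog) == 2, returning 0 when absent.
i = something(findfirst(isequal(:dog), labels), 0)
@assert i == 2
@assert something(findfirst(isequal(:fish), labels), 0) == 0
```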
@@ -9,7 +9,7 @@ struct Param{T}
Δ::T
end

Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zero(x))

include("optimisers.jl")
include("interface.jl")
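`zeros(x::AbstractArray)` was deprecated on 0.7; `zero(x)` builds a zero array matching `x`'s shape and element type, which is exactly what the optimiser state below wants:

```julia
x = rand(Float32, 3, 2)
Δ = zero(x)               # 0.6: zeros(x)
@assert size(Δ) == size(x) && eltype(Δ) == Float32 && all(iszero, Δ)
```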
@@ -14,7 +14,7 @@ function descentweightdecay(p::Param, η::Real, γ::Real)
end

function momentum(p::Param, ρ, η)
v = zeros(p.x)
v = zero(p.x)
function ()
@. v = ρ * v - η * p.Δ
@. p.Δ = -v

@@ -23,7 +23,7 @@ end

# Ref. https://arxiv.org/pdf/1212.0901.pdf
function nesterov(p::Param, ρ, η)
v = zeros(p.x)
v = zero(p.x)
function ()
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
@. v = ρ*v - η*p.Δ

@@ -32,7 +32,7 @@ function nesterov(p::Param, ρ, η)
end

function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
acc = zero(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= η / √(acc + ϵ)

@@ -40,7 +40,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
end

function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
acc = zero(p.x) .+ ϵ
function ()
@. acc += p.Δ^2
@. p.Δ *= η / √(acc + ϵ)

@@ -48,8 +48,8 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
end

function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
Δacc = zeros(p.x)
acc = zero(p.x)
Δacc = zero(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)

@@ -58,8 +58,8 @@ function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
end

function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x)
mt = zero(p.x)
vt = zero(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ

@@ -71,8 +71,8 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
end

function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
ut = zeros(p.x)
mt = zero(p.x)
ut = zero(p.x)
β1p = β1
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ

@@ -83,9 +83,9 @@ function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999,
end

function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x) .+ ϵ
v̂t = zeros(p.x) .+ ϵ
mt = zero(p.x)
vt = zero(p.x) .+ ϵ
v̂t = zero(p.x) .+ ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2

@@ -95,8 +95,8 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
end

function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x)
mt = zero(p.x)
vt = zero(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@@ -12,7 +12,7 @@ tracker(x) = nothing
istracked(x) = tracker(x) ≠ nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
grad(x) = grad(tracker(x))
grad(::Void) = nothing
grad(::Nothing) = nothing
data(x) = x

struct Call{F,As<:Tuple}

@@ -35,7 +35,7 @@ mutable struct Tracked{T}
grad::T
Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
end

istracked(x::Tracked) = true

@@ -46,14 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)

function _forward end

function track(f::F, xs...) where F
y, back = _forward(f, xs...)
ts = map(tracker, xs)
c = Call(back, ts)
track(c, y)
end

function track_kw(f::F, xs...; kw...) where F
function track(f::F, xs...; kw...) where F
y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y)
end

@@ -87,7 +80,7 @@ Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))

"""
checkpoint(f, args...)
@@ -1,3 +1,9 @@
import Base: *, ==

import LinearAlgebra
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm, diag

struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A}
data::A

@@ -21,24 +27,20 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)

TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))

Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}

Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")

function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr
print(io, "param(")
Base.showarray(io, data(X), true)
print(io, ")")
else
header && print(io, "Tracked ")
Base.showarray(io, data(X), false, header = header)
end
function Base.summary(io::IO, x::TrackedArray)
print(io, "Tracked ")
summary(io, data(x))
end

Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))

Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
@@ -58,9 +60,9 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =

Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)

Base.:(==)(x::TrackedArray, y) = data(x) == y
Base.:(==)(y, x::TrackedArray) = y == data(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
x::TrackedArray == y = data(x) == y
y == x::TrackedArray = y == data(x)
x::TrackedArray == y::TrackedArray = data(x) == data(y)

# Array Stdlib

@@ -79,29 +81,12 @@ Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,)

Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)

@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),)
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)

Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)

@grad function repmat(xs, m, n = 1)
repmat(data(xs), m, n), function (Δ)
Δ′ = similar(xs)
S = size(xs)
for (i,v) in enumerate(data(Δ))
d1 = divrem(i-1, S[1]*m)
x = d1[2] % S[1]+1
y = d1[1] % S[2]+1
Δ′[x, y] += v
end
return (nobacksies(:repmat, Δ′), nothing, nothing)
end
end

Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)

@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)

@@ -109,7 +94,7 @@ Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
S = size(xs)

# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]

@@ -119,8 +104,8 @@ Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
end
end


for f in [:vcat, :hcat]
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
# promotion mechanism for concatenation yet

@@ -129,18 +114,18 @@ for f in [:vcat, :hcat]
# It should support tracked concatenation with rank ∈ (1,2) with a
# TrackedArray anywhere among the arguments This works as long as base has
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
Base.$f(a::$UArray...) = track($f, a...)

# It should support tracked concatenation with rank>2 if the TrackedArray is
# first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row

# It should support tracked concatenation with rank>2 if the TrackedArray is
# second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
c::$UArray...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end
end
@@ -175,21 +160,23 @@ end
end
end

Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)

@grad function cat(dims, Xs...)
cat(dims, data.(Xs)...), function (Δ)
start = ntuple(i -> 0, Val{ndims(Δ)})
@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = [begin
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
d
end for xs in Xs]
return (nothing, Δs...,)
return (Δs...,)
end
end
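Two 0.7 changes drive this hunk: `cat(dims, xs...)` becomes `cat(xs...; dims = dims)`, and `ntuple` now takes the *instance* `Val(N)` rather than the type `Val{N}`:

```julia
a, b = ones(2, 2), zeros(2, 2)
c = cat(a, b; dims = 3)                       # 0.6: cat(3, a, b)
@assert size(c) == (2, 2, 2)
@assert ntuple(i -> 0, Val(3)) == (0, 0, 0)   # 0.6: ntuple(i -> 0, Val{3})
```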
@ -218,98 +205,95 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
@grad sum(xs, dim...) = sum(data(xs), dim...),
|
||||
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
|
||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, )
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||
@grad prod(xs, dim) = prod(data(xs), dim),
|
||||
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
||||
Δ -> (nobacksies(:sum,
|
||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||
nothing)
|
||||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
||||
|
||||
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
||||
Base.minimum(xs::TrackedArray) = track(minimum, xs)
|
||||
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
|
||||
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
||||
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
||||
|
||||
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
import LinearAlgebra: dot
|
||||
|
||||
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
# Hacks to get std working
|
||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
||||
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
||||
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
||||
|
||||
Base.vecnorm(x::TrackedArray, p::Real = 2) =
|
||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
|
||||
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
|
||||
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
||||
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
||||
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
||||
|
||||
@grad function maximum(xs, r...)
|
||||
maximum(data(xs), r...), function (Δ)
|
||||
@grad function maximum(xs; dims = dims)
|
||||
maximum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmax(data(xs), r...)
|
||||
_, i = findmax(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:maximum, Δ′),)
|
||||
end
|
||||
end
|
||||
@grad function minimum(xs, r...)
|
||||
minimum(data(xs), r...), function (Δ)
|
||||
|
||||
@grad function minimum(xs; dims = dims)
|
||||
minimum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmin(data(xs), r...)
|
||||
_, i = findmin(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:minimum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
||||
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
|
||||
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
|
||||
@eval begin
|
||||
import Base.$f
|
||||
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
|
||||
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
|
||||
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
|
||||
|
||||
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
|
||||
end
|
||||
end
|
||||
x::TrackedVector * y::AbstractVector = track(*, x, y)
|
||||
x::AbstractVector * y::TrackedVector = track(*, x, y)
|
||||
x::TrackedVector * y::TrackedVector = track(*, x, y)
|
||||
|
||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
|
||||
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
||||
|
||||
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
|
||||
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
# @grad function (a::AbstractMatrix * b::AbstractVecOrMat)
|
||||
# # @show size(a) size(b)
|
||||
# data(a)*data(b), function (Δ)
|
||||
# @show size(Δ) size(b) size(Δ*transpose(b)) size(Δ*transpose(data(b)))
|
||||
# @show typeof(Δ) typeof(b)
|
||||
# (Δ * transpose(b), transpose(a) * Δ)
|
||||
# end
|
||||
# end
|
||||
|
||||
# NNlib
|
||||
|
||||
|
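The bulk of this hunk is mechanical too: reduction dimensions move into a `dims` keyword, `mean`/`std` move to the `Statistics` stdlib, `vecnorm` becomes `LinearAlgebra.norm`, and the `Ac_mul_B` family is replaced by plain `*` on lazy `Adjoint`/`Transpose` wrappers. For example:

```julia
using Statistics, LinearAlgebra

x = [1.0 2.0; 3.0 4.0]
@assert sum(x, dims = 1) == [4.0 6.0]                    # 0.6: sum(x, 1)
@assert mean(x, dims = 2) == reshape([1.5, 3.5], 2, 1)   # 0.6: mean(x, 2)
@assert norm([3.0, 4.0]) == 5.0                          # 0.6: vecnorm([3.0, 4.0])
@assert x' * x == [10.0 14.0; 14.0 20.0]                 # 0.6: Ac_mul_B(x, x)
```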
@ -324,9 +308,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
|
@ -334,14 +318,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
|||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
|
@ -352,13 +336,16 @@ end
|
|||
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
_size(x::AbstractArray) = size(x)
|
||||
_size(x) = ()
|
||||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
||||
|
||||
unbroadcast(x::Tuple, Δ) =
|
||||
x == size(Δ) ? Δ :
|
||||
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
|
||||
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
||||
|
||||
|
@ -368,14 +355,14 @@ function getpartial(Δ, x, i)
|
|||
end
|
||||
|
||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||
sizes = size.(args)
|
||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
sizes = _size.(args)
|
||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N)))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
y = value.(out)
|
||||
back = function (Δ_)
|
||||
Δ = data(Δ_)
|
||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
|
||||
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
||||
nobacksies(:broadcast, dxs)
|
||||
end
|
||||
|
@@ -383,14 +370,14 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
track(Call(back, tracker.(args)), y)
end

Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
using Base.Broadcast: BroadcastStyle

Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
struct TrackedStyle <: BroadcastStyle end

Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()

function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
bc = Broadcast.flatten(bc)
∇broadcast(bc.f, bc.args...)
end
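These deletions and additions are the move from 0.6's `_containertype`/`promote_containertype` broadcast machinery to 0.7's `BroadcastStyle` protocol: define a style for your type, make it win style promotion, and overload `copy` on the resulting lazy `Broadcasted`. A stripped-down sketch of the same pattern for a hypothetical `Wrapped` type:

```julia
using Base.Broadcast: BroadcastStyle, Broadcasted

struct Wrapped{T}; data::T; end

struct WrappedStyle <: BroadcastStyle end
Broadcast.BroadcastStyle(::Type{<:Wrapped}) = WrappedStyle()
Broadcast.BroadcastStyle(::WrappedStyle, ::BroadcastStyle) = WrappedStyle()  # Wrapped wins

Broadcast.broadcastable(w::Wrapped) = w    # treat Wrapped as its own container
Base.axes(w::Wrapped) = axes(w.data)

unwrap(x) = x
unwrap(x::Wrapped) = x.data

# Materialise by unwrapping the flattened arguments, then re-wrapping the result.
function Base.copy(bc::Broadcasted{WrappedStyle})
    bc = Broadcast.flatten(bc)
    Wrapped(broadcast(bc.f, map(unwrap, bc.args)...))
end

@assert (Wrapped([1, 2]) .+ 1).data == [2, 3]
```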
@@ -26,7 +26,7 @@ function back_(c::Call, Δ)
foreach(back, c.args, data.(Δs))
end

back_(::Call{Void}, Δ) = nothing
back_(::Call{Nothing}, Δ) = nothing

accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)

@@ -47,7 +47,7 @@ function back(x::Tracked, Δ)
return
end

back(::Void, _) = return
back(::Nothing, _) = return

# Interface methods
@@ -79,14 +79,14 @@ function Base.show(io::IO, ps::Params)
end

struct Grads
grads::ObjectIdDict
grads::IdDict{Any,Any}
end

Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")

Grads() = Grads(ObjectIdDict())
Grads() = Grads(IdDict())

Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))

Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
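`ObjectIdDict` becomes the parametric `IdDict{K,V}` on 0.7. It still hashes keys by object identity, which is exactly what gradient storage wants: two equal-valued but distinct parameter arrays must remain distinct keys.

```julia
a, b = [1, 2], [1, 2]      # equal values, distinct objects
d = IdDict{Any,Any}()      # 0.6: ObjectIdDict()
d[a] = :first
d[b] = :second
@assert a == b && length(d) == 2
```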
@@ -96,7 +96,7 @@ end

@forward Grads.grads Base.setindex!, Base.haskey

accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)

@@ -105,7 +105,7 @@ function back_(g::Grads, c::Call, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end

back_(g::Grads, ::Call{Void}, Δ) = nothing
back_(g::Grads, ::Call{Nothing}, Δ) = nothing

function back(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return)

@@ -119,7 +119,7 @@ function back(g::Grads, x::Tracked, Δ)
return
end

back(::Grads, ::Void, _) = return
back(::Grads, ::Nothing, _) = return

function forward(f, ps::Params)
y = f()
@@ -1,17 +1,17 @@
struct IdSet{T} <: AbstractSet{T}
dict::ObjectIdDict
IdSet{T}() where T = new(ObjectIdDict())
dict::IdDict{T,Nothing}
IdSet{T}() where T = new(IdDict{T,Nothing}())
end

Base.eltype{T}(::IdSet{T}) = T
Base.eltype(::IdSet{T}) where T = T

IdSet() = IdSet{Any}()

Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)

(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...)
(::Type{IdSet{T}})(xs) where T = push!(IdSet{T}(), xs...)

IdSet(xs) = IdSet{eltype(xs)}(xs)
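The rest of this hunk is the 0.7 method-syntax migration: parametric methods written `f{T}(x)` must now spell the parameter with a trailing `where T`. Illustration with a hypothetical `eltype_of`:

```julia
# 0.6 syntax (removed on 0.7):  Base.eltype{T}(::IdSet{T}) = T
# 0.7 syntax:
eltype_of(::AbstractSet{T}) where T = T
@assert eltype_of(Set([1, 2])) == Int
```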
@@ -1,5 +1,5 @@
function ngradient(f, xs::AbstractArray...)
grads = zeros.(xs)
grads = zero.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x)
δ = sqrt(eps())
tmp = x[i]
@@ -115,3 +115,7 @@ end
function back_(c::Call{typeof(collect)}, Δ)
foreach(back, c.args[1], data(Δ))
end

function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end
@@ -7,16 +7,27 @@ mapchildren(f, x) = x
children(x::Tuple) = x
mapchildren(f, x::Tuple) = map(f, x)

function treelike(T, fs = fieldnames(T))
@eval current_module() begin
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
end
end

function treelike(T, fs = fieldnames(T))
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
treelike(Base._current_module(), T, fs)
end

macro treelike(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
end

isleaf(x) = isempty(children(x))

function mapleaves(f, x; cache = ObjectIdDict())
function mapleaves(f, x; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end
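The upshot for layer authors: `treelike(T)` relied on `current_module()`, which 0.7 removed, so it is deprecated in favour of the `@treelike` macro, which expands in the calling module via `@__MODULE__`. Usage matches the docs hunk earlier in this diff (a sketch, assuming Flux is loaded):

```julia
using Flux

struct Affine
  W
  b
end

Flux.@treelike Affine   # 0.6 spelling: Flux.treelike(Affine)

# children/mapchildren are now defined for Affine, so parameter collection
# and cpu/gpu movement can recurse into W and b.
```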
@@ -53,7 +64,7 @@ cpu(m) = mapleaves(x -> adapt(Array, x), m)

gpu_adaptor = identity

@require CuArrays begin
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
global gpu_adaptor = CuArrays.cu
end
@@ -1,8 +1,8 @@
# Arrays

initn(dims...) = randn(dims...)/100
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))

unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
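The new dots are required, not stylistic: 0.7 deprecates adding or subtracting a scalar from an array, so `rand(dims...) - 0.5` must become `.- 0.5` (and dotting the multiply lets the whole expression fuse). For example:

```julia
x = rand(4) .- 0.5               # 0.6 allowed rand(4) - 0.5
@assert all(-0.5 .<= x .<= 0.5)
w = randn(4) .* sqrt(2.0 / 4)    # fused scale; same values as randn(4) * sqrt(2.0 / 4)
```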
@@ -119,7 +119,7 @@ function throttle(f, timeout; leading=true, trailing=false)
end

cooldown = false
@schedule try
@async try
while (sleep(timeout); later != nothing)
later()
later = nothing
@@ -145,7 +145,7 @@ function jacobian(m,x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(n,k)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator
J[:,i] = xp.grad
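One more 0.7 change hides here: uninitialized-array constructors now take an explicit `undef` flag.

```julia
J = Matrix{Float64}(undef, 3, 2)   # 0.6: Matrix{Float64}(3, 2)
@assert size(J) == (3, 2)
```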
@@ -1,4 +1,4 @@
using Flux, Flux.Tracker, CuArrays, Base.Test
using Flux, Flux.Tracker, CuArrays, Test
using Flux: gpu

info("Testing Flux/GPU")

@@ -1,4 +1,4 @@
using Flux, Flux.Tracker, CuArrays, Base.Test
using Flux, Flux.Tracker, CuArrays, Test
using Flux.Tracker: TrackedArray, data

@testset "CUDNN BatchNorm" begin

@@ -1,5 +1,5 @@
using Flux.Data
using Base.Test
using Test

@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args

@@ -5,7 +5,7 @@ using Flux.Tracker: data
x = [1.,2.,3.]
@test x == testmode!(Dropout(0.1))(x)
@test x == Dropout(0)(x)
@test zeros(x) == Dropout(1)(x)
@test zero(x) == Dropout(1)(x)

x = rand(100)
m = Dropout(0.9)

@@ -1,4 +1,4 @@
using Base.Test
using Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy

@@ -1,4 +1,4 @@
using Flux, Base.Test
using Flux, Test, Random

srand(0)

@@ -11,8 +11,8 @@ include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")

if Base.find_in_path("CuArrays") ≠ nothing
include("cuda/cuda.jl")
end
# if Base.find_in_path("CuArrays") ≠ nothing
#   include("cuda/cuda.jl")
# end

end
@@ -1,6 +1,11 @@
using Flux.Tracker, Base.Test, NNlib
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
# using StatsBase

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...)

@@ -12,11 +17,14 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))

@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@test gradtest(x -> sum(x), randn(Float64,2,3))
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))

@test gradtest(x -> softmax(x).*(1:3), 3)

@@ -48,8 +56,8 @@ function promotiontest(f, A, B, C)
end

@testset "concat" begin
cat1(x...) = cat(1, x...)
cat2(x...) = cat(2, x...)
cat1(x...) = cat(x..., dims = 1)
cat2(x...) = cat(x..., dims = 2)

@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))

@@ -71,17 +79,17 @@ end
@test gradtest(hcatf, rand(5), rand(5,2))
end

@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end

@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))

@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(dim, x...)
catdim = (x...) -> cat(x..., dims = dim)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))

@@ -91,10 +99,10 @@ end
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)

@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))

@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))

@@ -105,16 +113,12 @@ end
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

# TODO unreliable
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))

@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@@ -129,49 +133,49 @@ end
@testset "mean" begin
@test gradtest(mean, rand(2, 3))

@test gradtest(x -> mean(x, 1), rand(2, 3))
@test gradtest(x -> mean(x, 2), rand(2, 3))
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
@test gradtest(x -> mean(x, dims=1), rand(2, 3))
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))

@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))

@test gradtest(x -> maximum(x, 1), rand(2, 3))
@test gradtest(x -> maximum(x, 2), rand(2, 3))
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))

@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))

@test gradtest(x -> minimum(x, 1), rand(2, 3))
@test gradtest(x -> minimum(x, 2), rand(2, 3))
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))

@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end

@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))

@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))

@test gradtest(vecnorm, rand(5))
@test gradtest(norm, rand(5))

@test gradtest(rand(5)) do x
y = x.^2
2y + x
end

@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))

@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))

@@ -211,14 +215,11 @@ end
@testset "Fallbacks" begin
xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64}
# Remove this test if we do LowerTriangular properly
L = LowerTriangular(xs)
@test L*L' isa Matrix{TrackedReal{Float64}}
end

@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))

b = param(rand())
Tracker.back!(b)

@@ -231,6 +232,11 @@ Tracker.back!(b)
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)

@test Tracker.gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
end

# Gradient Hooks
@@ -1,4 +1,6 @@
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
using StatsBase: std
using Dates

@testset "Throttle" begin
@testset "default behaviour" begin