Merge branch 'master' into curnn
This commit is contained in:
commit
a1d1930097
7
REQUIRE
7
REQUIRE
|
@ -3,8 +3,13 @@ DataFlow 0.2.1
|
|||
Juno
|
||||
MacroTools 0.3.3
|
||||
NNlib
|
||||
ForwardDiff 0.5.0
|
||||
Requires
|
||||
Adapt
|
||||
GZip
|
||||
Colors
|
||||
|
||||
# AD
|
||||
ForwardDiff 0.5.0
|
||||
DiffRules
|
||||
SpecialFunctions
|
||||
NaNMath
|
||||
|
|
|
@ -18,7 +18,8 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
|
|||
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
import .Tracker: data, value
|
||||
export Tracker
|
||||
import .Tracker: data
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Juno
|
||||
using Flux.Tracker: back!, value
|
||||
using Flux.Tracker: back!
|
||||
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ())
|
|||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
l = loss(d...)
|
||||
isinf(value(l)) && error("Loss is Inf")
|
||||
isnan(value(l)) && error("Loss is NaN")
|
||||
isinf(l) && error("Loss is Inf")
|
||||
isnan(l) && error("Loss is NaN")
|
||||
back!(l)
|
||||
opt()
|
||||
cb() == :stop && break
|
||||
|
|
|
@ -2,8 +2,12 @@ module Tracker
|
|||
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
||||
|
||||
data(x) = x
|
||||
istracked(x) = false
|
||||
tracker(x) = nothing
|
||||
|
||||
istracked(x) = tracker(x) ≠ nothing
|
||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||
data(x) = istracked(x) ? data(tracker(x)) : x
|
||||
grad(x) = grad(tracker(x))
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
|
@ -14,87 +18,59 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||
|
||||
(c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||
mutable struct Tracked{T}
|
||||
ref::UInt32
|
||||
f::Call
|
||||
data::A
|
||||
grad::A
|
||||
TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
|
||||
TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
|
||||
data::T
|
||||
grad::T
|
||||
Tracked{T}(f::Call, data::T) where T = new(0, f, data)
|
||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
|
||||
end
|
||||
|
||||
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(c, x)
|
||||
track(f::Call, x) = Tracked(f, x)
|
||||
track(f::Call) = track(f, f())
|
||||
track(f, xs...) = track(Call(f, xs...))
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
|
||||
|
||||
TrackedArray(c::Call) = TrackedArray(c, c())
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
|
||||
isleaf(x::TrackedArray) = x.f == Call(nothing)
|
||||
|
||||
param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
|
||||
param(xs::Real) = param(fill(xs))
|
||||
|
||||
istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.data
|
||||
grad(x::TrackedArray) = x.grad
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
# TODO decide if keeping both data and value. The problem is TrackedScalar
|
||||
value(x) = x
|
||||
value(x::TrackedArray) = data(x)
|
||||
value(x::TrackedScalar) = data(x)[]
|
||||
|
||||
Base.:(==)(x::TrackedArray, y) = value(x) == y
|
||||
Base.:(==)(y, x::TrackedArray) = y == value(x)
|
||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
|
||||
|
||||
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
|
||||
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
|
||||
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
|
||||
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||
if repr
|
||||
print(io, "param(")
|
||||
Base.showarray(io, data(X), true)
|
||||
print(io, ")")
|
||||
else
|
||||
header && print(io, "Tracked ")
|
||||
Base.showarray(io, data(X), false, header = header)
|
||||
end
|
||||
end
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
istracked(x::Tracked) = true
|
||||
isleaf(x::Tracked) = x.f == Call(nothing)
|
||||
data(x::Tracked) = x.data
|
||||
grad(x::Tracked) = x.grad
|
||||
|
||||
include("back.jl")
|
||||
include("lib.jl")
|
||||
include("scalar.jl")
|
||||
include("array.jl")
|
||||
include("numeric.jl")
|
||||
|
||||
param(x::Number) = TrackedNumber(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
||||
using DataFlow
|
||||
using DataFlow: inputnode, constant
|
||||
|
||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||
|
||||
_graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
||||
|
||||
function _graph(x, inputs...; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
i = findfirst(inputs, x)
|
||||
cache[x] =
|
||||
i > 0 ? inputnode(i) :
|
||||
istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
|
||||
constant(x)
|
||||
end
|
||||
|
||||
function graph(f, args...)
|
||||
inputs = param.(args)
|
||||
_graph(f(inputs...), inputs...)
|
||||
end
|
||||
|
||||
import Adapt.adapt
|
||||
|
||||
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||
|
||||
end
|
||||
|
|
|
@ -1,72 +1,130 @@
|
|||
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
||||
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
tracker::Tracked{A}
|
||||
data::A
|
||||
grad::A
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||
end
|
||||
|
||||
unarray(xs) = xs
|
||||
unarray(xs::AbstractArray{T,0} where T) = xs[]
|
||||
tracker(x::TrackedArray) = x.tracker
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) =
|
||||
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
|
||||
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||
if repr
|
||||
print(io, "param(")
|
||||
Base.showarray(io, data(X), true)
|
||||
print(io, ")")
|
||||
else
|
||||
header && print(io, "Tracked ")
|
||||
Base.showarray(io, data(X), false, header = header)
|
||||
end
|
||||
end
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
|
||||
back!(::TrackedArray) = error("Use back!(x, Δ)")
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
Base.:(==)(x::TrackedArray, y) = data(x) == y
|
||||
Base.:(==)(y, x::TrackedArray) = y == data(x)
|
||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||
|
||||
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||
Δ′ = zeros(xs.data)
|
||||
Δ′[i...] = unarray(Δ)
|
||||
Δ′[i...] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
|
||||
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
|
||||
|
||||
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
|
||||
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
||||
|
||||
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
|
||||
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
||||
|
||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
|
||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
|
||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
|
||||
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
|
||||
|
||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
|
||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
|
||||
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
|
||||
|
||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
|
||||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
||||
|
||||
function back(::typeof(vcat), Δ, xs, ys)
|
||||
function back(::typeof(vcat), Δ, xs...)
|
||||
i = Base.tail(map(_ -> :, size(Δ)))
|
||||
@back(xs, Δ[1:size(xs,1), i...])
|
||||
@back(ys, Δ[size(xs,1)+1:end, i...])
|
||||
start = 0
|
||||
for xsi in xs
|
||||
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||
start += size(xsi, 1)
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
||||
TrackedArray(Call(reshape, xs, dims...))
|
||||
track(reshape, xs, dims...)
|
||||
|
||||
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||
back(xs, reshape(Δ, size(xs)))
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
|
||||
Base.sum(xs::TrackedScalar, dim...) = xs
|
||||
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
|
||||
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
|
||||
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
|
||||
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
|
||||
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
|
||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
|
||||
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
||||
function back(::typeof(dot), Δ, xs, ys)
|
||||
@back(xs, Δ.*ys)
|
||||
|
@ -85,20 +143,23 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
|||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
||||
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||
@eval begin
|
||||
import Base.$f
|
||||
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
|
||||
$f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
|
||||
$f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
|
||||
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
|
||||
$f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||
$f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
|
||||
$f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
|
||||
|
||||
$f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||
$f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
|
||||
$f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -132,11 +193,11 @@ end
|
|||
using NNlib
|
||||
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
|
||||
|
||||
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
||||
|
||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||
|
||||
logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
|
||||
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
||||
|
||||
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
|
||||
|
||||
|
@ -144,11 +205,11 @@ back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
|
|||
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
|
||||
|
||||
conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
|
||||
TrackedArray(Call(_conv2d, x, w, stride, padding))
|
||||
track(_conv2d, x, w, stride, padding)
|
||||
conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
|
||||
TrackedArray(Call(_conv2d, x, w, stride, padding))
|
||||
track(_conv2d, x, w, stride, padding)
|
||||
conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
|
||||
TrackedArray(Call(_conv2d, x, w, stride, padding))
|
||||
track(_conv2d, x, w, stride, padding)
|
||||
|
||||
function back(::typeof(_conv2d), Δ, x, w, stride, pad)
|
||||
@back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
|
||||
|
@ -158,7 +219,7 @@ end
|
|||
_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
|
||||
|
||||
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
|
||||
TrackedArray(Call(_pool, x, window, padding, mode))
|
||||
track(_pool, x, window, padding, mode)
|
||||
|
||||
back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
|
||||
back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
|
||||
|
@ -167,7 +228,8 @@ back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
|
|||
|
||||
using ForwardDiff: Dual, partials
|
||||
|
||||
struct Broadcasted{T}
|
||||
struct Broadcasted{F,T}
|
||||
f::F
|
||||
data::T
|
||||
end
|
||||
|
||||
|
@ -175,23 +237,24 @@ end
|
|||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
||||
dualify(xs::TrackedNumber, ps) = Dual(data(xs), ps)
|
||||
|
||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||
# Works around a 0.6 type inference issue
|
||||
b = Broadcasted(out)
|
||||
TrackedArray(Call(b, args...), b())
|
||||
b = Broadcasted(f, out)
|
||||
track(Call(b, args...), b())
|
||||
end
|
||||
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
|
||||
|
||||
unbroadcast(x, Δ) =
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
|
||||
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
|
||||
function getpartial(Δ, x, i)
|
||||
@inbounds p = getindex(partials(x), i)
|
||||
return Δ * p
|
|
@ -1,25 +1,33 @@
|
|||
scan(x) = nothing
|
||||
init_grad(x) = zero(x)
|
||||
|
||||
scan(c::Call) = foreach(scan, c.args)
|
||||
|
||||
function scan(x::TrackedArray)
|
||||
function scan(x::Tracked)
|
||||
ref = x.ref += 1
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
else
|
||||
isdefined(x, :grad) || (x.grad = zeros(x.data))
|
||||
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
function scan(x)
|
||||
istracked(x) && scan(tracker(x))
|
||||
return
|
||||
end
|
||||
|
||||
back_(f, y, args...) = back(f, args...)
|
||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||
back_(::Call{Void}, y, Δ) = nothing
|
||||
|
||||
function back(x::TrackedArray, Δ)
|
||||
accum!(x, Δ) = x .+ Δ
|
||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
|
||||
function back(x::Tracked, Δ)
|
||||
ref = x.ref -= 1
|
||||
if isdefined(x, :grad)
|
||||
x.grad .+= Δ
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
ref == 0 && back_(x.f, x.data, x.grad)
|
||||
else
|
||||
ref == 0 && back_(x.f, x.data, Δ)
|
||||
|
@ -27,6 +35,9 @@ function back(x::TrackedArray, Δ)
|
|||
return
|
||||
end
|
||||
|
||||
back(x, Δ) = back(tracker(x), Δ)
|
||||
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
||||
|
||||
macro back(x, Δ)
|
||||
quote
|
||||
x = $(esc(x))
|
||||
|
@ -39,9 +50,9 @@ end
|
|||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
|
||||
function back!(x::TrackedArray, Δ)
|
||||
function back!(x::Tracked, Δ)
|
||||
scan(x)
|
||||
back(x, Δ)
|
||||
end
|
||||
|
||||
back!(x::TrackedScalar) = back!(x, 1)
|
||||
back!(x, Δ) = back!(tracker(x), Δ)
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
struct TrackedNumber{T<:Number} <: Number
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x))
|
||||
|
||||
tracker(x::TrackedNumber) = x.tracker
|
||||
|
||||
track(f::Call, x::Number) = TrackedNumber(Tracked(f, x))
|
||||
|
||||
back!(x::TrackedNumber) = back!(x, 1)
|
||||
|
||||
function Base.show(io::IO, x::TrackedNumber)
|
||||
show(io, data(x))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber{T}) where T = x
|
||||
|
||||
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber) where T =
|
||||
TrackedNumber(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
||||
|
||||
Base.convert(::Type{TrackedNumber{T}}, x::Number) where T = TrackedNumber(convert(T, x))
|
||||
|
||||
Base.isless(x::TrackedNumber, y::Number) = isless(data(x), y)
|
||||
Base.isless(x::Number, y::TrackedNumber) = isless(x, data(y))
|
||||
Base.isless(x::TrackedNumber, y::TrackedNumber) = isless(data(x), data(y))
|
||||
|
||||
Base.:(==)(x::TrackedNumber, y::Number) = data(x) == y
|
||||
Base.:(==)(x::Number, y::TrackedNumber) = x == data(y)
|
||||
Base.:(==)(x::TrackedNumber, y::TrackedNumber) = data(x) == data(y)
|
||||
|
||||
for f in :[isinf, isnan, isfinite].args
|
||||
@eval Base.$f(x::TrackedNumber) = Base.$f(data(x))
|
||||
end
|
||||
|
||||
Base.Printf.fix_dec(x::TrackedNumber, n::Int) = Base.Printf.fix_dec(data(x), n)
|
||||
|
||||
Base.promote_rule(::Type{TrackedNumber{S}},::Type{T}) where {S,T} =
|
||||
TrackedNumber{promote_type(S,T)}
|
||||
|
||||
using DiffRules, SpecialFunctions, NaNMath
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 1 || continue
|
||||
@eval begin
|
||||
$M.$f(a::TrackedNumber) = track($M.$f, a)
|
||||
back(::typeof($M.$f), Δ::Number, a::TrackedNumber) =
|
||||
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
|
||||
end
|
||||
end
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 2 || continue
|
||||
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
|
||||
@eval begin
|
||||
$M.$f(a::TrackedNumber, b::TrackedNumber) = track($M.$f, a, b)
|
||||
$M.$f(a::TrackedNumber, b::Number) = track($M.$f, a, b)
|
||||
$M.$f(a::Number, b::TrackedNumber) = track($M.$f, a, b)
|
||||
function back(::typeof($M.$f), Δ::Number, a::Number, b::Number)
|
||||
@back(a, Δ * $da)
|
||||
@back(b, Δ * $db)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Tuples
|
||||
|
||||
struct TrackedTuple{T<:Tuple}
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
tracker(xs::TrackedTuple) = xs.tracker
|
||||
|
||||
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||
init_grad(x::Tuple) = init_grad.(x)
|
||||
|
||||
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
||||
|
||||
function Base.show(io::IO, xs::TrackedTuple)
|
||||
show(io, data(xs))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.length(x::TrackedTuple) = length(data(x))
|
||||
|
||||
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
||||
|
||||
back(::typeof(getindex), Δ, t, i) =
|
||||
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
|
|
@ -2,7 +2,7 @@ using Flux.Tracker, Base.Test, NNlib
|
|||
using Flux.Tracker: gradcheck
|
||||
using NNlib
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
|
||||
@testset "Tracker" begin
|
||||
|
@ -13,14 +13,12 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||
|
||||
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
||||
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
|
||||
|
||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||
|
||||
## uncomment the following test when logsoftmax has been added into NNlib.jl
|
||||
#@test gradtest(x -> logsoftmax(x).*(1:3), 3)
|
||||
#@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
|
||||
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
|
||||
|
||||
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
@ -28,7 +26,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(x -> x', rand(5))
|
||||
|
||||
@test gradtest(vcat, rand(5), rand(3))
|
||||
@test gradtest(vcat, rand(2,3), rand(3,3))
|
||||
@test gradtest(vcat, rand(5), rand(3), rand(8))
|
||||
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
||||
|
||||
@test gradtest(diagm, rand(3))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
|
Loading…
Reference in New Issue