Flux.jl/src/tracker/lib.jl

189 lines
7.1 KiB
Julia
Raw Normal View History

2017-08-22 11:24:08 +00:00
# Wrap a forward-pass result so it is always an array.
# An array result is returned unchanged; the first argument only matters when
# the result is a scalar.
toarray(::AbstractArray, ys::AbstractArray) = ys

# A scalar `y` is stored in a fresh 0-dimensional array built with `similar`
# from `xs`, with element type `typeof(y)`.
function toarray(xs::AbstractArray, y)
  cell = similar(xs, typeof(y), ())
  cell .= y
  return cell
end
2017-08-22 11:24:08 +00:00
2017-09-03 21:10:23 +00:00
# Inverse of `toarray` for scalars: a 0-dimensional array is unwrapped to its
# single element; every other value is returned as-is.
unarray(xs) = xs
function unarray(xs::AbstractArray{<:Any,0})
  return xs[]
end
2017-08-22 11:24:08 +00:00
# Tracked indexing: record the call and store the selected value. A scalar
# selection is wrapped into a 0-dimensional array by `toarray`.
function Base.getindex(xs::TrackedArray, i...)
  picked = xs.data[i...]
  TrackedArray(Call(getindex, xs, i...), toarray(xs.data, picked))
end
2017-08-19 09:14:50 +00:00
2017-09-07 03:09:32 +00:00
# Gradient of `xs[i...]`: scatter Δ back into a zero array shaped like `xs.data`.
# Contributions are *accumulated* through a view, so an index selected more
# than once (e.g. `xs[[1, 1]]`) receives the sum of its gradients. A plain
# `Δ′[i...] = Δ` lets the last write win and silently drops the others.
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
  Δ′ = zeros(xs.data)
  slots = view(Δ′, i...)
  # `unarray` turns a 0-dim Δ (scalar indexing) into a plain number, which
  # iterates as one element alongside the single slot of a 0-dim view.
  for (k, δ) in zip(eachindex(slots), unarray(Δ))
    slots[k] += δ
  end
  @back(xs, Δ′)
end
# Unary negation of a tracked array.
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))

# d(-x)/dx = -1, so the incoming gradient is passed on negated.
function back(::typeof(-), Δ, xs::TrackedArray)
  negΔ = -Δ
  back(xs, negΔ)
end
2017-08-23 16:50:43 +00:00
2017-09-01 15:42:18 +00:00
# Tracked (conjugate) transposition.
for f in (:transpose, :ctranspose)
  @eval Base.$f(xs::TrackedArray) = TrackedArray(Call($f, xs))
end

# The gradient is the same transposition applied to Δ. `trim` reshapes it
# back to `ndims(xs)` dimensions, so e.g. a vector input gets a vector gradient.
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, transpose(Δ)))
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, ctranspose(Δ)))
2017-09-03 21:10:23 +00:00
2017-09-05 06:12:53 +00:00
# Tracked `repmat`. The generic `Integer...` method serves any integer counts.
# The `Int...` method presumably exists to resolve a method ambiguity with
# Base's `Int`-typed `repmat` methods (NOTE(review): confirm). It now uses `Int`
# rather than `Int64` so it matches Base's signature on 32-bit platforms too.
Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
Base.repmat(x::TrackedVecOrMat, a::Int...) = TrackedArray(Call(repmat, x, a...))

# Gradient of `repmat(xs, m, n)`: every input element is copied into m×n tiles,
# so its gradient is the sum of Δ over all tiles.
# An (r, c) input (c = 1 for a vector) yields an (r*m, c*n) result. Viewed as
# (r, m, c, n) in column-major order, entry [i, p, j, q] is tile (p, q) at
# position (i, j); summing dims 2 and 4 folds the tiles together.
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m::Integer, n::Integer = 1)
  x = data(xs)
  r, c = size(x, 1), size(x, 2)
  Δ′ = sum(reshape(Δ, r, m, c, n), (2, 4))
  back(xs, reshape(Δ′, size(x)))
end
2017-09-03 06:12:44 +00:00
# Tracked `vcat` for any mix of tracked and untracked vector/matrix operands.
# The Vector and Matrix pairs are spelled out next to the VecOrMat pairs,
# presumably to stay at least as specific as Base's own `vcat` methods and
# avoid dispatch ambiguities (NOTE(review): confirm).
for (Ta, Tb) in ((:TrackedVector,    :TrackedVector),
                 (:TrackedVector,    :AbstractVector),
                 (:AbstractVector,   :TrackedVector),
                 (:TrackedVecOrMat,  :TrackedVecOrMat),
                 (:TrackedVecOrMat,  :AbstractVecOrMat),
                 (:AbstractVecOrMat, :TrackedVecOrMat),
                 (:TrackedMatrix,    :TrackedMatrix),
                 (:TrackedMatrix,    :AbstractMatrix),
                 (:AbstractMatrix,   :TrackedMatrix))
  @eval Base.vcat(a::$Ta, b::$Tb) = TrackedArray(Call(vcat, a, b))
end
2017-09-05 06:11:28 +00:00
2017-09-07 03:09:32 +00:00
# Gradient of `vcat(xs, ys)`: split Δ along the first dimension. The first
# `size(xs, 1)` rows belong to `xs` and the remaining rows to `ys`.
# Each slice is passed through `trim` so it has its operand's dimensionality.
# For example, in `vcat(v, M)` with a vector `v` and an n×1 matrix `M`, Δ is a
# matrix; without `trim`, `v` would receive an n×1 slice instead of a vector.
function back(::typeof(vcat), Δ, xs, ys)
  # One colon per dimension after the first: the split is along dim 1 only.
  rest = Base.tail(map(_ -> :, size(Δ)))
  nx = size(xs, 1)
  @back(xs, trim(xs, Δ[1:nx, rest...]))
  @back(ys, trim(ys, Δ[nx+1:end, rest...]))
end
2017-08-22 11:24:08 +00:00
# Reductions
2017-08-22 14:12:12 +00:00
# Sum along `dim`: the result array is produced when the call is evaluated.
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))

# Full reduction: the scalar total is stored as a 0-dimensional array via `toarray`.
function Base.sum(xs::TrackedArray)
  total = sum(xs.data)
  TrackedArray(Call(sum, xs), toarray(xs.data, total))
end

# A tracked scalar is returned unchanged, whatever (possibly empty) dims are given.
Base.sum(xs::TrackedScalar, dim...) = xs

# Every input element contributes with weight 1, so the gradient is Δ broadcast
# to the shape of `xs.data`. This also covers `sum(xs, dim)`, where Δ keeps
# size 1 along the reduced dimensions.
function back(::typeof(sum), Δ, xs::TrackedArray, dim...)
  spread = similar(xs.data)
  spread .= Δ
  back(xs, spread)
end
2017-08-22 11:24:08 +00:00
2017-09-07 01:21:35 +00:00
# Queries evaluated directly on the untracked data; their results are not tracked.
for f in (:maximum, :findfirst)
  @eval Base.$f(xs::TrackedArray, args...) = $f(xs.data, args...)
end
2017-09-02 03:33:05 +00:00
2017-10-31 10:41:44 +00:00
# Mean over all elements: the scalar is stored as a 0-dimensional array via `toarray`.
function Base.mean(xs::TrackedArray)
  μ = mean(xs.data)
  TrackedArray(Call(mean, xs), toarray(xs.data, μ))
end

# Mean over `region`.
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
2017-12-12 17:23:15 +00:00
# Dot product where at least one operand is tracked. The scalar result is
# stored as a 0-dimensional array via `toarray`.
# `data` is used for both operands, so an untracked `AbstractVector` operand
# works too (previously `xs.data` was read from it, and a plain `Vector` has
# no `.data` field).
_trackeddot(xs, ys) =
  TrackedArray(Call(dot, xs, ys), toarray(data(xs), dot(data(xs), data(ys))))

LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = _trackeddot(xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = _trackeddot(xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = _trackeddot(xs, ys)

# ∂(x⋅y)/∂x = y and ∂(x⋅y)/∂y = x. The untracked values are used so that the
# gradient is a plain array instead of a new tracked broadcast.
function back(::typeof(dot), Δ, xs, ys)
  @back(xs, Δ .* data(ys))
  @back(ys, Δ .* data(xs))
end
2017-11-21 16:04:04 +00:00
# `std` for tracked arrays, written in terms of tracked broadcasting, `sum` and
# `mean`, so no dedicated gradient rule is needed. Uses the n - 1 divisor.
# The keyword `mean` shadows the function inside the body; defaults call `Base.mean`.
function Base.std(x::TrackedArray; mean = Base.mean(x))
  sqdev = (x .- mean).^2
  sqrt.(sum(sqdev) ./ (length(x) - 1))
end

function Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim))
  sqdev = (x .- mean).^2
  sqrt.(sum(sqdev, dim) ./ (size(x, dim) - 1))
end
2017-10-31 10:41:44 +00:00
# Each input element contributes 1/n to the mean, so the gradient is Δ / n
# broadcast to the shape of `xs.data`. Here n is the number of averaged elements.
function back(::typeof(mean), Δ, xs::TrackedArray)
  n = length(xs.data)
  back(xs, similar(xs.data) .= Δ ./ n)
end

# With `region`, n is the product of the sizes along the averaged dimensions.
function back(::typeof(mean), Δ, xs::TrackedArray, region)
  n = prod(size(xs.data, region...))
  back(xs, similar(xs.data) .= Δ ./ n)
end
2017-08-22 11:24:08 +00:00
# BLAS
2017-12-12 17:07:39 +00:00
# Tracked matrix products (`*`, `Ac_mul_B`, `A_mul_Bc`) for every combination
# of tracked and untracked matrix/vector operands that has at least one
# tracked side.
for f in (:*, :Ac_mul_B, :A_mul_Bc)
  @eval import Base.$f
  for (Ta, Tb) in ((:TrackedMatrix,  :TrackedMatrix),
                   (:TrackedMatrix,  :AbstractMatrix),
                   (:AbstractMatrix, :TrackedMatrix),
                   (:TrackedMatrix,  :TrackedVector),
                   (:TrackedMatrix,  :AbstractVector),
                   (:AbstractMatrix, :TrackedVector),
                   (:TrackedVector,  :TrackedVector),
                   (:TrackedVector,  :AbstractVector),
                   (:AbstractVector, :TrackedVector))
    @eval $f(a::$Ta, b::$Tb) = TrackedArray(Call($f, a, b))
  end
end
2017-08-20 12:48:43 +00:00
2017-09-07 03:09:32 +00:00
# C = A * B  ⇒  ∂A = Δ Bᵀ,  ∂B = Aᵀ Δ  (computed on the untracked values).
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
  A, B = data(a), data(b)
  @back(a, A_mul_Bt(Δ, B))
  @back(b, At_mul_B(A, Δ))
end
2017-11-08 22:00:19 +00:00
# C = Aᵀ B for real inputs  ⇒  ∂A = (Δ Bᵀ)ᵀ = B Δᵀ,  ∂B = A Δ.
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
  A, B = data(a), data(b)
  @back(a, A_mul_Bt(Δ, B)')
  @back(b, A*Δ)
end
# C = A Bᵀ for real inputs  ⇒  ∂A = Δ B,  ∂B = (Aᵀ Δ)ᵀ = Δᵀ A.
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
  A, B = data(a), data(b)
  @back(a, Δ * B)
  @back(b, At_mul_B(A, Δ)')
end
2017-11-07 19:34:27 +00:00
# Fast path for matrix-vector products W * x.
# When W is a leaf (per `isleaf`), the outer product Δ xᵀ is accumulated
# straight into `W.grad` by a single fused broadcast, without building it as a
# temporary. Otherwise the materialised outer product is propagated via `back`.
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
  xv = data(x)
  if !isleaf(W)
    back(W, A_mul_Bt(Δ, xv))
  else
    W.grad .+= Δ .* xv.'
  end
  @back(x, At_mul_B(data(W), Δ))
end
2017-08-23 01:03:17 +00:00
# NNlib
2017-12-14 18:48:38 +00:00
using NNlib
import NNlib: softmax, ∇softmax, conv2d
2017-08-23 01:03:17 +00:00
# Tracked softmax; its gradient is delegated to NNlib's `∇softmax` on the untracked input.
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))

function back(::typeof(softmax), Δ, xs)
  @back(xs, ∇softmax(Δ, data(xs)))
end
2017-08-23 01:03:17 +00:00
2017-12-14 18:48:38 +00:00
# Tracked 2-D convolution of 4-dimensional arrays, where input and/or weights are tracked.
conv2d(x::TrackedArray{T,4}, w::TrackedArray{S,4}) where {T,S} = TrackedArray(Call(conv2d, x, w))
conv2d(x::AbstractArray{T,4}, w::TrackedArray{S,4}) where {T,S} = TrackedArray(Call(conv2d, x, w))
conv2d(x::TrackedArray{T,4}, w::AbstractArray{S,4}) where {T,S} = TrackedArray(Call(conv2d, x, w))

# Gradients for input and weights come from NNlib, computed on the untracked values.
function back(::typeof(conv2d), Δ, x, w)
  xd, wd = data(x), data(w)
  @back(x, NNlib.conv2d_grad_x(xd, wd, Δ))
  @back(w, NNlib.conv2d_grad_w(xd, wd, Δ))
end
2017-08-19 15:02:19 +00:00
# Broadcasting
using ForwardDiff: Dual, partials
# Result of broadcasting over dualified arguments. `data` holds the array of
# dual numbers; `back` reads it again to extract the per-argument partials.
struct Broadcasted{T}
  data::T
end

# Calling the wrapper ignores its arguments and strips each dual to its `.value`.
(bc::Broadcasted)(ignored...) = map(d -> d.value, bc.data)
# Arguments that are not tracked pass through unchanged.
dualify(xs, ps) = xs

# A tracked array becomes an array of `Dual`s, each element seeded with the
# same partials tuple `ps`.
function dualify(xs::TrackedArray, ps)
  seed(v) = Dual(v, ps)
  map(seed, data(xs))
end
2017-08-19 15:02:19 +00:00
# Broadcast `f` over `args` with forward-mode duals. Argument k is seeded with
# the one-hot tuple `ntuple(j -> k == j, Val{N})`, so the k-th partial of each
# output dual holds the derivative with respect to argument k. Untracked
# arguments pass through `dualify` unchanged.
function tracked_broadcast(f, args::Vararg{Any,N}) where N
  duals = map((arg, k) -> dualify(arg, ntuple(j -> k == j, Val{N})), args, ntuple(identity, Val{N}))
  # The Broadcasted value is built first and then used to construct the
  # TrackedArray, rather than in one expression, to work around a Julia 0.6
  # type inference issue.
  bc = Broadcasted(broadcast(f, duals...))
  TrackedArray(Call(bc, args...), bc())
end
2017-08-27 08:49:42 +00:00
# Reshape Δ to exactly `ndims(x)` dimensions, keeping its leading sizes.
# Extra trailing dimensions are dropped, and `reshape` only succeeds if they are
# singletons. Missing dimensions get size 1, since `size(Δ, d)` is 1 past `ndims(Δ)`.
function trim(x, Δ)
  dims = ntuple(d -> size(Δ, d), Val{ndims(x)})
  return reshape(Δ, dims)
end
2017-08-22 23:25:19 +00:00
# Reduce a broadcast gradient Δ back to the shape of argument `x`.
# Along every dimension where `x` has size 1 (i.e. was stretched by
# broadcasting), the contributions are summed; `trim` then fixes the number of
# dimensions. If the shapes already match, Δ is returned as-is.
function unbroadcast(x, Δ)
  size(x) == size(Δ) && return Δ
  stretched = filter(d -> size(x, d) == 1, 1:ndims(Δ))
  return trim(x, sum(Δ, stretched))
end
2017-08-22 23:25:19 +00:00
2017-08-28 00:40:59 +00:00
# Contribution of one output element to argument `i`'s gradient: the upstream
# gradient Δ times the i-th partial of the dual `x` (bounds checks skipped).
getpartial(Δ, x, i) = Δ * (@inbounds partials(x)[i])
2017-09-07 03:09:32 +00:00
# Gradient of a tracked broadcast. For argument k, Δ is multiplied elementwise
# by the k-th partial stored in each output dual. The result is then reduced to
# that argument's shape with `unbroadcast` and propagated; `@back` skips
# untracked arguments.
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
  Δargs = ntuple(k -> getpartial.(Δ, b.data, k), Val{N})
  for (arg, Δarg) in zip(args, Δargs)
    @back(arg, unbroadcast(arg, Δarg))
  end
end
2017-08-19 15:02:19 +00:00
# Hook TrackedArray into Julia 0.6's broadcast machinery. Any broadcast
# involving a TrackedArray is promoted to the TrackedArray container type,
# whether alone or mixed with Arrays or other containers. It is then routed
# to `tracked_broadcast`.
Base.Broadcast._containertype(::Type{T}) where {T<:TrackedArray} = TrackedArray

# The explicit TrackedArray/TrackedArray and Array pairs presumably disambiguate
# against the generic rules below and Base's own methods (NOTE(review): confirm).
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, other) = TrackedArray
Base.Broadcast.promote_containertype(other, ::Type{TrackedArray}) = TrackedArray

# A `Ref` contributes no dimensions (it broadcasts like a scalar); anything
# else uses its own indices.
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, r::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, a) = indices(a)

Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, a, bs...) = tracked_broadcast(f, a, bs...)