forward mode

This commit is contained in:
Mike J Innes 2019-01-25 16:13:34 +00:00
parent 962ce88c0d
commit 58ac415f6b
3 changed files with 67 additions and 2 deletions

View File

@ -66,13 +66,15 @@ include("back.jl")
include("numeric.jl")
include("lib/real.jl")
include("lib/array.jl")
include("forward.jl")
"""
hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`."""
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))

53
src/tracker/forward.jl Normal file
View File

@ -0,0 +1,53 @@
using ForwardDiff
seed(x::Real, ::Val) = Dual(x, true)
function seed(x, ::Val{N}, offset = 0) where N
map(x, reshape(1:length(x), size(x))) do x, i
Dual(x, ntuple(j -> j+offset == i, Val(N)))
end
end
extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
J = similar(xs, V, N, length(xs))
for i = 1:length(xs), j = 1:N
J[j, i] = xs[i].partials.values[j]
end
return map(x -> x.value, xs), J
end
function forward_jacobian(f, x, ::Val{N}) where N
y, _J = extract(f(seed(x, Val(N))))
J = similar(_J, length(x), length(y))
J[1:N,:] = _J
offset = 0
while offset + N < length(x)
offset += N
_, _J = extract(f(seed(x, Val(N), offset)))
range = (1+offset):min(N+offset,length(x))
J[range,:] = @view _J[range.-offset,:]
end
return y, J
end
function forward_jacobian(f, x)
if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
forward_jacobian(f, x, Val(length(x)))
else
forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
end
end
forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)
vec_scalar(x) = vec(x)
vec_scalar(x::Real) = [x]
reshape_scalar(x, y) = reshape(y, size(x))
reshape_scalar(x::Real, y) = y[]
@grad function forwarddiff(f, x)
y, J = forward_jacobian(f, data(x))
return y, -> (nothing, reshape_scalar(x, J*vec_scalar()))
end

View File

@ -1,6 +1,6 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
@ -305,4 +305,14 @@ end
@test gs[W] == dW
end
@testset "Forward" begin
@test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
reshape(ones(25), :, 1)
@test gradient([2, 3]) do x
forwarddiff(x) do x
x[1]*x[2]
end
end == ([3, 2],)
end
end #testset