forward mode

parent 962ce88c0d
commit 58ac415f6b
@@ -66,13 +66,15 @@ include("back.jl")
 include("numeric.jl")
 include("lib/real.jl")
 include("lib/array.jl")
+include("forward.jl")
 
 """
     hook(f, x) -> x′
 
 Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
 `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
-the sign of the gradient applied to `x`."""
+the sign of the gradient applied to `x`.
+"""
 hook(f, x) = istracked(x) ? track(hook, f, x) : x
 @grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
 
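As the amended docstring says, `hook` leaves the forward value of `x` untouched and applies `f` only to the gradient flowing back into it. A minimal usage sketch (my own example values, using the `gradient` API that the test changes below exercise):

    using Flux.Tracker: gradient, hook

    # Forward pass: hook(-, x) returns x unchanged, so this computes 3x = 6.
    # Backward pass: the incoming gradient 3 is negated by the hook.
    gradient(2.0) do x
      3 * hook(-, x)
    end  # == (-3.0,)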
forward.jl (new file):

@@ -0,0 +1,53 @@
+using ForwardDiff
+
+seed(x::Real, ::Val) = Dual(x, true)
+
+function seed(x, ::Val{N}, offset = 0) where N
+  map(x, reshape(1:length(x), size(x))) do x, i
+    Dual(x, ntuple(j -> j+offset == i, Val(N)))
+  end
+end
+
+extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
+
+function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
+  J = similar(xs, V, N, length(xs))
+  for i = 1:length(xs), j = 1:N
+    J[j, i] = xs[i].partials.values[j]
+  end
+  return map(x -> x.value, xs), J
+end
+
+function forward_jacobian(f, x, ::Val{N}) where N
+  y, _J = extract(f(seed(x, Val(N))))
+  J = similar(_J, length(x), length(y))
+  J[1:N,:] = _J
+  offset = 0
+  while offset + N < length(x)
+    offset += N
+    _, _J = extract(f(seed(x, Val(N), offset)))
+    range = (1+offset):min(N+offset,length(x))
+    J[range,:] = @view _J[range.-offset,:]
+  end
+  return y, J
+end
+
+function forward_jacobian(f, x)
+  if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
+    forward_jacobian(f, x, Val(length(x)))
+  else
+    forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
+  end
+end
+
+forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x)
+
+vec_scalar(x) = vec(x)
+vec_scalar(x::Real) = [x]
+reshape_scalar(x, y) = reshape(y, size(x))
+reshape_scalar(x::Real, y) = y[]
+
+@grad function forwarddiff(f, x)
+  y, J = forward_jacobian(f, data(x))
+  return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
+end
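To make the new file concrete: `seed` wraps the input in `ForwardDiff.Dual` numbers whose partials mark `N` inputs at a time, `extract` recovers primal values and partials, and `forward_jacobian` sweeps the input in chunks of `N`, filling a matrix with inputs along rows and outputs along columns, i.e. the transpose of the usual Jacobian. A sketch with made-up values (qualified as `Tracker.forward_jacobian`, as in the test below, since the function is not exported):

    using Flux, Flux.Tracker

    x = [1.0, 2.0, 3.0]
    # Chunk size 2 < length(x), so the while loop makes a second sweep over x.
    y, J = Tracker.forward_jacobian(x -> [sum(x), prod(x)], x, Val(2))
    # y == [6.0, 6.0]
    # J[i, j] == ∂y[j]/∂x[i]:
    # J == [1.0  6.0
    #       1.0  3.0
    #       1.0  2.0]

Storing the Jacobian input-major is what lets the `@grad forwarddiff` pullback above be a plain `J * vec_scalar(ȳ)`, with no transpose.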
@@ -1,6 +1,6 @@
 using Flux
 using Flux.Tracker, Test, NNlib
-using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
+using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
 using NNlib: conv, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm
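`forwarddiff`, newly imported here, is the user-facing entry point: on a tracked array it records a single tape node whose forward value is computed with dual numbers and whose pullback multiplies by the stored Jacobian. A hedged sketch of mixing it into an ordinary reverse-mode `gradient` call (example values are mine):

    using Flux, Flux.Tracker
    using Flux.Tracker: gradient, forwarddiff

    # Elementwise square has Jacobian diag(2x) at x = [1, 2, 3]; sum
    # contributes ȳ = [1, 1, 1], so the pullback J * ȳ is [2.0, 4.0, 6.0].
    gradient(x -> sum(forwarddiff(y -> y .^ 2, x)), [1.0, 2.0, 3.0])
    # == ([2.0, 4.0, 6.0],)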
@@ -305,4 +305,14 @@ end
 @test gs[W] == dW
 end
 
+@testset "Forward" begin
+  @test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
+    reshape(ones(25), :, 1)
+  @test gradient([2, 3]) do x
+    forwarddiff(x) do x
+      x[1]*x[2]
+    end
+  end == ([3, 2],)
+end
+
 end #testset