conv gradient
This commit is contained in:
parent
5b97d2ba04
commit
d949b31aa5
@ -123,12 +123,22 @@ end
|
|||||||
|
|
||||||
# NNlib
|
# NNlib
|
||||||
|
|
||||||
import NNlib: softmax, ∇softmax
|
using NNlib
|
||||||
|
import NNlib: softmax, ∇softmax, conv2d
|
||||||
|
|
||||||
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||||
|
|
||||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||||
|
|
||||||
|
conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w))
|
||||||
|
conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w))
|
||||||
|
conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}) = TrackedArray(Call(conv2d, x, w))
|
||||||
|
|
||||||
|
function back(::typeof(conv2d), Δ, x, w)
|
||||||
|
@back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ))
|
||||||
|
@back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ))
|
||||||
|
end
|
||||||
|
|
||||||
# Broadcasting
|
# Broadcasting
|
||||||
|
|
||||||
using ForwardDiff: Dual, partials
|
using ForwardDiff: Dual, partials
|
||||||
|
@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
|
|||||||
return grads
|
return grads
|
||||||
end
|
end
|
||||||
|
|
||||||
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
|
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux.Tracker, Base.Test, NNlib
|
||||||
using Flux.Tracker: gradcheck
|
using Flux.Tracker: gradcheck
|
||||||
|
using NNlib
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
||||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||||
@ -45,4 +46,6 @@ end
|
|||||||
2y + x
|
2y + x
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user