some more changes
This commit is contained in:
parent
d6f5baee39
commit
ea38c7dbea
|
@ -221,7 +221,7 @@ Base.prod(xs::TrackedArray) = track(prod, xs)
|
|||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||
@grad prod(xs, dim) = prod(data(xs), dim),
|
||||
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
||||
Δ -> (nobacksies(:sum,
|
||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||
nothing)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
function ngradient(f, xs::AbstractArray...)
|
||||
grads = zeros.(xs)
|
||||
grads = zero.(xs)
|
||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||
δ = sqrt(eps())
|
||||
tmp = x[i]
|
||||
|
|
|
@ -2,7 +2,7 @@ using Flux
|
|||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||
using NNlib: conv
|
||||
using StatsBase
|
||||
# using StatsBase
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
|
@ -14,13 +14,13 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
|
||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
||||
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
||||
|
||||
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(2,3))
|
||||
@test gradtest(x -> sum(x), randn(2,3))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
||||
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x), (3,4,5))
|
||||
|
||||
|
@ -170,9 +170,9 @@ end
|
|||
2y + x
|
||||
end
|
||||
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
@ -216,7 +216,7 @@ end
|
|||
|
||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
|
||||
|
||||
b = param(rand())
|
||||
Tracker.back!(b)
|
||||
|
|
Loading…
Reference in New Issue