Fix back scalar with a Ref and fix diagonal test

This commit is contained in:
Josh Christie 2018-08-11 14:27:56 +01:00
parent 89881a9b21
commit 710a65fe72
2 changed files with 6 additions and 8 deletions

View File

@ -137,7 +137,7 @@ end
function forward(f, args...)
args = param.(args)
y, back = forward(() -> f(args...), Params(args))
y, Δ -> getindex.(back(Δ), args)
y, Δ -> getindex.(Ref(back(Δ)), args)
end
function losscheck(x)

View File

@ -3,23 +3,20 @@ using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@ -36,7 +33,6 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5))
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
@ -69,6 +65,7 @@ end
@test gradtest(vcatf, rand(5)', rand(2,5))
end
@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@ -97,7 +94,7 @@ end
@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)
@test !isa(cat(rand(2), dims=1), TrackedArray)
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
@ -115,6 +112,7 @@ end
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@ -128,7 +126,7 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(diagm, rand(3))
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
@testset "mean" begin
@test gradtest(mean, rand(2, 3))