fixed fixes proposed by Carlo
This commit is contained in:
parent
8ab209126d
commit
d6f5baee39
|
@ -27,7 +27,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, similar(x) .= 0)
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
||||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
|
|
|
@ -14,13 +14,13 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
|
||||
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
||||
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||
|
||||
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(2,3))
|
||||
@test gradtest(x -> sum(x), randn(2,3))
|
||||
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x), (3,4,5))
|
||||
|
||||
|
@ -170,9 +170,9 @@ end
|
|||
2y + x
|
||||
end
|
||||
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
@ -216,7 +216,7 @@ end
|
|||
|
||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
|
||||
|
||||
b = param(rand())
|
||||
Tracker.back!(b)
|
||||
|
|
Loading…
Reference in New Issue