commit
5e4ee827e9
|
@ -1,4 +1,4 @@
|
|||
import Base: *, ==
|
||||
import Base: *
|
||||
|
||||
import LinearAlgebra
|
||||
using Statistics
|
||||
|
@ -60,9 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
x::TrackedArray == y = data(x) == y
|
||||
y == x::TrackedArray = y == data(x)
|
||||
x::TrackedArray == y::TrackedArray = data(x) == data(y)
|
||||
for op in [:(==), :≈]
|
||||
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
|
|
|
@ -30,8 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
|
|||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||
error("Not implemented: convert tracked $S to tracked $T")
|
||||
|
||||
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
|
||||
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
|
||||
for op in [:(==), :≈, :<]
|
||||
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
||||
end
|
||||
|
||||
Base.eps(x::TrackedReal) = eps(data(x))
|
||||
|
||||
|
|
|
@ -182,9 +182,30 @@ end
|
|||
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
|
||||
|
||||
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||
@testset "equality & order" begin
|
||||
# TrackedReal
|
||||
@test param(2)^2 == param(4)
|
||||
@test param(2)^2 == 4
|
||||
@test 4 == param(2)^2
|
||||
|
||||
@test param(2)^2 == 4.0
|
||||
@test param(2)^2 ≈ param(4)
|
||||
@test param(2)^2 ≈ 4
|
||||
@test 4 ≈ param(2)^2
|
||||
|
||||
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||
@test (param([1,2,3]) .<= 2) == [true, true, false]
|
||||
@test (2 .> param([1,2,3])) == [true, false, false]
|
||||
@test (2 .>= param([1,2,3])) == [true, true, false]
|
||||
|
||||
# TrackedArray
|
||||
@test param([1,2,3]).^2 == param([1,4,9])
|
||||
@test [1,2,3].^2 == param([1,4,9])
|
||||
@test param([1,2,3]).^2 == [1,4,9]
|
||||
|
||||
@test param([1,2,3]).^2 ≈ param([1,4,9])
|
||||
@test [1,2,3].^2 ≈ param([1,4,9])
|
||||
@test param([1,2,3]).^2 ≈ [1,4,9]
|
||||
end
|
||||
|
||||
@testset "reshape" begin
|
||||
x = reshape(param(rand(2,2,2)), 4, 2)
|
||||
|
|
Loading…
Reference in New Issue