Fix issue #323
This commit is contained in:
Mike J Innes 2018-09-06 15:28:15 +01:00 committed by GitHub
commit 5e4ee827e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 8 deletions

View File

@ -1,4 +1,4 @@
import Base: *, == import Base: *
import LinearAlgebra import LinearAlgebra
using Statistics using Statistics
@ -60,9 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
x::TrackedArray == y = data(x) == y for op in [:(==), :≈]
y == x::TrackedArray = y == data(x) @eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
x::TrackedArray == y::TrackedArray = data(x) == data(y) @eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
end
# Array Stdlib # Array Stdlib

View File

@ -30,8 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T") error("Not implemented: convert tracked $S to tracked $T")
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) for op in [:(==), :≈, :<]
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end
Base.eps(x::TrackedReal) = eps(data(x)) Base.eps(x::TrackedReal) = eps(data(x))

View File

@ -182,9 +182,30 @@ end
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2)) @test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2)) @test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test (param([1,2,3]) .< 2) == [true, false, false] @testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@test param(2)^2 == 4
@test 4 == param(2)^2
@test param(2)^2 == 4.0 @test param(2)^2 param(4)
@test param(2)^2 4
@test 4 param(2)^2
@test (param([1,2,3]) .< 2) == [true, false, false]
@test (param([1,2,3]) .<= 2) == [true, true, false]
@test (2 .> param([1,2,3])) == [true, false, false]
@test (2 .>= param([1,2,3])) == [true, true, false]
# TrackedArray
@test param([1,2,3]).^2 == param([1,4,9])
@test [1,2,3].^2 == param([1,4,9])
@test param([1,2,3]).^2 == [1,4,9]
@test param([1,2,3]).^2 param([1,4,9])
@test [1,2,3].^2 param([1,4,9])
@test param([1,2,3]).^2 [1,4,9]
end
@testset "reshape" begin @testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2) x = reshape(param(rand(2,2,2)), 4, 2)