From 36f5f274a572810891314aa4265833e82aa40d78 Mon Sep 17 00:00:00 2001 From: JohnnyChen Date: Tue, 9 Oct 2018 01:53:32 +0800 Subject: [PATCH 1/4] Support copy(::TrackedArray) 1. fix issue https://github.com/FluxML/Flux.jl/issues/416 2. change test code to pass the test: some broken tests are not broken now... --- src/tracker/scalar.jl | 2 ++ test/tracker.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 1b6098fb..ad7b643d 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -23,6 +23,8 @@ end Base.decompose(x::TrackedReal) = Base.decompose(data(x)) +Base.convert(::Type{T}, x::TrackedReal{S}) where {T<:Real,S} = convert(T, data(x)) + Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) diff --git a/test/tracker.jl b/test/tracker.jl index a4772f2e..7d7168ad 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -40,7 +40,7 @@ function promotiontest(f, A, B, C) if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat] r3 = f(A, B, param(C)) else - @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved + @test_broken f(A, B, param(C)) # until julia#20815 is resolved r3 = r2 end r4 = f(param(A), param(B), param(C)) From 27fec15fcc5fc9af64edf533377f206f2be06443 Mon Sep 17 00:00:00 2001 From: JohnnyChen Date: Tue, 9 Oct 2018 03:34:41 +0800 Subject: [PATCH 2/4] Add explicit copy(x::TrackedArray) method --- src/tracker/array.jl | 2 ++ src/tracker/scalar.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 3d9836d0..b8b06471 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -43,6 +43,8 @@ end Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x)) +Base.copy(x::TrackedArray) = copy(data(x)) + Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index ad7b643d..e0ae7db1 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -23,6 +23,8 @@ end Base.decompose(x::TrackedReal) = Base.decompose(data(x)) +Base.copy(x::TrackedArray) = copy(data(x)) + Base.convert(::Type{T}, x::TrackedReal{S}) where {T<:Real,S} = convert(T, data(x)) Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x From eaacec852fe6a78f7d77bc38e755e1c7c5b1a0d9 Mon Sep 17 00:00:00 2001 From: JohnnyChen Date: Tue, 9 Oct 2018 03:40:02 +0800 Subject: [PATCH 3/4] Bug fix --- src/tracker/scalar.jl | 2 +- test/tracker.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index e0ae7db1..ba83d937 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -23,7 +23,7 @@ end Base.decompose(x::TrackedReal) = Base.decompose(data(x)) -Base.copy(x::TrackedArray) = copy(data(x)) +Base.copy(x::TrackedReal) = copy(data(x)) Base.convert(::Type{T}, x::TrackedReal{S}) where {T<:Real,S} = convert(T, data(x)) diff --git a/test/tracker.jl b/test/tracker.jl index 7d7168ad..a4772f2e 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -40,7 +40,7 @@ function promotiontest(f, A, B, C) if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat] r3 = f(A, B, param(C)) else - @test_broken f(A, B, param(C)) # until julia#20815 is resolved + @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved r3 = r2 end r4 = f(param(A), param(B), param(C)) From de7623ac94a81f47048e5ee149eb5fd449d2cdc5 Mon Sep 17 00:00:00 2001 From: JohnnyChen Date: Tue, 9 Oct 2018 03:49:17 +0800 Subject: [PATCH 4/4] use variable assignment to do "copy" --- src/tracker/array.jl | 2 +- src/tracker/scalar.jl | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index b8b06471..00fe4cc4 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -43,7 +43,7 @@ end Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x)) -Base.copy(x::TrackedArray) = copy(data(x)) +Base.copy(x::TrackedArray) = x Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index ba83d937..e37ee843 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -23,9 +23,7 @@ end Base.decompose(x::TrackedReal) = Base.decompose(data(x)) -Base.copy(x::TrackedReal) = copy(data(x)) - -Base.convert(::Type{T}, x::TrackedReal{S}) where {T<:Real,S} = convert(T, data(x)) +Base.copy(x::TrackedReal) = x Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x