diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index a2a6c745..8e6a584a 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -39,6 +39,8 @@ TrackedArray(c::Call) = TrackedArray(c, c()) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) param(xs) = TrackedArray(AbstractFloat.(xs)) +param(xs::Real) = param(fill(xs)) + istracked(x::TrackedArray) = true data(x::TrackedArray) = x.data grad(x::TrackedArray) = x.grad diff --git a/test/tracker.jl b/test/tracker.jl index dca7e13c..2a20338e 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -27,4 +27,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) 2y + x end +for T in [Float32, Float64] + @test isa(param(T(1)), TrackedArray{T, 0}) + @test isa(param(rand(T, 2)), TrackedArray{T, 1}) + @test isa(param(rand(T, 2,2)), TrackedArray{T, 2}) end + +# TODO: do we wand this behaviour ?? +F = typeof(AbstractFloat(1)) +for T in [Int32, Int64] + @test isa(param(T(1)), TrackedArray{F, 0}) + @test isa(param(rand(T, 2)), TrackedArray{F, 1}) + @test isa(param(rand(T, 2,2)), TrackedArray{F, 2}) +end + +end #testset