From 00a9e5f01fb283323230361658e060810a1973f1 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 12 Oct 2017 10:56:23 +0200 Subject: [PATCH] construct TrackedScalar with params(1) --- src/tracker/Tracker.jl | 2 ++ test/tracker.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index a2a6c745..8e6a584a 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -39,6 +39,8 @@ TrackedArray(c::Call) = TrackedArray(c, c()) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) param(xs) = TrackedArray(AbstractFloat.(xs)) +param(xs::Real) = param(fill(xs)) + istracked(x::TrackedArray) = true data(x::TrackedArray) = x.data grad(x::TrackedArray) = x.grad diff --git a/test/tracker.jl b/test/tracker.jl index dca7e13c..2a20338e 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -27,4 +27,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) 2y + x end +for T in [Float32, Float64] + @test isa(param(T(1)), TrackedArray{T, 0}) + @test isa(param(rand(T, 2)), TrackedArray{T, 1}) + @test isa(param(rand(T, 2,2)), TrackedArray{T, 2}) end + +# TODO: do we wand this behaviour ?? +F = typeof(AbstractFloat(1)) +for T in [Int32, Int64] + @test isa(param(T(1)), TrackedArray{F, 0}) + @test isa(param(rand(T, 2)), TrackedArray{F, 1}) + @test isa(param(rand(T, 2,2)), TrackedArray{F, 2}) +end + +end #testset