construct TrackedScalar with params(1)

This commit is contained in:
CarloLucibello 2017-10-12 10:56:23 +02:00 committed by Mike J Innes
parent 2a66545ef8
commit 00a9e5f01f
2 changed files with 16 additions and 0 deletions

View File

@ -39,6 +39,8 @@ TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
param(xs) = TrackedArray(AbstractFloat.(xs))
param(xs::Real) = param(fill(xs))
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data
grad(x::TrackedArray) = x.grad

View File

@ -27,4 +27,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
2y + x
end
for T in [Float32, Float64]
@test isa(param(T(1)), TrackedArray{T, 0})
@test isa(param(rand(T, 2)), TrackedArray{T, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
end
# TODO: do we wand this behaviour ??
F = typeof(AbstractFloat(1))
for T in [Int32, Int64]
@test isa(param(T(1)), TrackedArray{F, 0})
@test isa(param(rand(T, 2)), TrackedArray{F, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
end
end #testset