construct TrackedScalar with params(1)
This commit is contained in:
parent
2a66545ef8
commit
00a9e5f01f
@ -39,6 +39,8 @@ TrackedArray(c::Call) = TrackedArray(c, c())
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
|
||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||
param(xs::Real) = param(fill(xs))
|
||||
|
||||
istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.data
|
||||
grad(x::TrackedArray) = x.grad
|
||||
|
@ -27,4 +27,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
2y + x
|
||||
end
|
||||
|
||||
for T in [Float32, Float64]
|
||||
@test isa(param(T(1)), TrackedArray{T, 0})
|
||||
@test isa(param(rand(T, 2)), TrackedArray{T, 1})
|
||||
@test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
|
||||
end
|
||||
|
||||
# TODO: do we wand this behaviour ??
|
||||
F = typeof(AbstractFloat(1))
|
||||
for T in [Int32, Int64]
|
||||
@test isa(param(T(1)), TrackedArray{F, 0})
|
||||
@test isa(param(rand(T, 2)), TrackedArray{F, 1})
|
||||
@test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
Loading…
Reference in New Issue
Block a user