fixes #111
This commit is contained in:
parent
f22cfb5b43
commit
236edbffec
|
@ -22,6 +22,8 @@ TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
|||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux.Tracker: gradcheck
|
||||
using Flux.Tracker: TrackedReal, gradcheck
|
||||
using NNlib
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
|
@ -67,6 +67,14 @@ end
|
|||
@test x.grad == [8]
|
||||
end
|
||||
|
||||
@testset "Fallbacks" begin
|
||||
xs = param([1 2; 3 4])
|
||||
@test similar(xs) isa Matrix{Float64}
|
||||
# Remove this test if we do LowerTriangular properly
|
||||
L = LowerTriangular(xs)
|
||||
@test L*L' isa Matrix{TrackedReal{Float64}}
|
||||
end
|
||||
|
||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue