This commit is contained in:
Mike J Innes 2018-02-13 10:20:38 +00:00
parent f22cfb5b43
commit 236edbffec
2 changed files with 11 additions and 1 deletions

View File

@ -22,6 +22,8 @@ TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")

View File

@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: gradcheck
using Flux.Tracker: TrackedReal, gradcheck
using NNlib
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -67,6 +67,14 @@ end
@test x.grad == [8]
end
@testset "Fallbacks" begin
xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64}
# Remove this test if we do LowerTriangular properly
L = LowerTriangular(xs)
@test L*L' isa Matrix{TrackedReal{Float64}}
end
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
end #testset