fixes #111
This commit is contained in:
parent
f22cfb5b43
commit
236edbffec
@ -22,6 +22,8 @@ TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
|||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||||
|
|
||||||
|
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||||
|
|
||||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||||
print(io, "TrackedArray{…,$A}")
|
print(io, "TrackedArray{…,$A}")
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux.Tracker, Base.Test, NNlib
|
||||||
using Flux.Tracker: gradcheck
|
using Flux.Tracker: TrackedReal, gradcheck
|
||||||
using NNlib
|
using NNlib
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||||
@ -67,6 +67,14 @@ end
|
|||||||
@test x.grad == [8]
|
@test x.grad == [8]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Fallbacks" begin
|
||||||
|
xs = param([1 2; 3 4])
|
||||||
|
@test similar(xs) isa Matrix{Float64}
|
||||||
|
# Remove this test if we do LowerTriangular properly
|
||||||
|
L = LowerTriangular(xs)
|
||||||
|
@test L*L' isa Matrix{TrackedReal{Float64}}
|
||||||
|
end
|
||||||
|
|
||||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user