see #205
This commit is contained in:
parent
1105e3ac20
commit
2370bdbe91
@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
|
|||||||
|
|
||||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
||||||
|
|
||||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
|
# This cuts derivatives, fix if needed.
|
||||||
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
|
||||||
|
# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
||||||
|
|
||||||
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
||||||
|
|
||||||
@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
|||||||
|
|
||||||
back(::typeof(getindex), Δ, t, i) =
|
back(::typeof(getindex), Δ, t, i) =
|
||||||
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
|
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
|
||||||
|
|
||||||
|
# Array collection
|
||||||
|
|
||||||
|
function collect(xs)
|
||||||
|
xs = Base.collect(xs)
|
||||||
|
track(Call(collect, xs), data.(xs))
|
||||||
|
end
|
||||||
|
|
||||||
|
function scan(c::Call{typeof(collect)})
|
||||||
|
foreach(scan, c.args[1])
|
||||||
|
end
|
||||||
|
|
||||||
|
function back(::typeof(collect), Δ, xs)
|
||||||
|
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
|
||||||
|
end
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux.Tracker, Base.Test, NNlib
|
||||||
using Flux.Tracker: TrackedReal, gradcheck
|
using Flux.Tracker: TrackedReal, gradcheck, grad
|
||||||
using NNlib: conv
|
using NNlib: conv
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||||
@ -220,4 +220,13 @@ b = param(rand())
|
|||||||
Tracker.back!(b)
|
Tracker.back!(b)
|
||||||
@test Tracker.grad(b) == 1
|
@test Tracker.grad(b) == 1
|
||||||
|
|
||||||
|
@testset "collect" begin
|
||||||
|
x, y = param(2), param(3)
|
||||||
|
xy = Tracker.collect([x, y])
|
||||||
|
@test xy isa TrackedArray{Float64}
|
||||||
|
z = xy[1]*xy[2]
|
||||||
|
back!(z)
|
||||||
|
@test grad.((x,y)) == (3, 2)
|
||||||
|
end
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user