see #205
This commit is contained in:
parent
1105e3ac20
commit
2370bdbe91
|
@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
|
|||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
|
||||
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
||||
# This cuts derivatives, fix if needed.
|
||||
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
|
||||
# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
||||
|
||||
|
@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
|||
|
||||
back(::typeof(getindex), Δ, t, i) =
|
||||
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
|
||||
|
||||
# Array collection
|
||||
|
||||
function collect(xs)
|
||||
xs = Base.collect(xs)
|
||||
track(Call(collect, xs), data.(xs))
|
||||
end
|
||||
|
||||
function scan(c::Call{typeof(collect)})
|
||||
foreach(scan, c.args[1])
|
||||
end
|
||||
|
||||
function back(::typeof(collect), Δ, xs)
|
||||
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
|
||||
end
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad
|
||||
using NNlib: conv
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
|
@ -220,4 +220,13 @@ b = param(rand())
|
|||
Tracker.back!(b)
|
||||
@test Tracker.grad(b) == 1
|
||||
|
||||
@testset "collect" begin
|
||||
x, y = param(2), param(3)
|
||||
xy = Tracker.collect([x, y])
|
||||
@test xy isa TrackedArray{Float64}
|
||||
z = xy[1]*xy[2]
|
||||
back!(z)
|
||||
@test grad.((x,y)) == (3, 2)
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue