fixes #114
This commit is contained in:
parent
b7b6c975bc
commit
29787eba45
@ -58,6 +58,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
|||||||
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
|
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
|
||||||
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
||||||
|
|
||||||
|
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
|
||||||
|
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
|
||||||
|
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
|
||||||
|
|
||||||
|
function back(::typeof(dot), Δ, xs, ys)
|
||||||
|
@back(xs, Δ.*ys)
|
||||||
|
@back(ys, Δ.*xs)
|
||||||
|
end
|
||||||
|
|
||||||
# Hacks to get std working
|
# Hacks to get std working
|
||||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||||
|
@ -38,6 +38,8 @@ end
|
|||||||
@test gradtest(x -> std(x), rand(5,5))
|
@test gradtest(x -> std(x), rand(5,5))
|
||||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||||
|
|
||||||
|
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
|
||||||
|
|
||||||
@test gradtest(rand(5)) do x
|
@test gradtest(rand(5)) do x
|
||||||
y = x.^2
|
y = x.^2
|
||||||
2y + x
|
2y + x
|
||||||
|
Loading…
Reference in New Issue
Block a user