std derivative

This commit is contained in:
Mike J Innes 2017-11-21 17:04:04 +01:00
parent b06884b912
commit 351d3d4771
3 changed files with 11 additions and 2 deletions

View File

@ -83,9 +83,9 @@ end
Diagonal(in::Integer)
Creates an element-wise linear transformation layer with learnable
vectors α and β:
vectors `α` and `β`:
y = α .* x .+ b
y = α .* x .+ β
The input `x` must be a array where `size(x, 1) == in`.
"""

View File

@ -58,6 +58,12 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
# Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) =
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))

View File

@ -34,6 +34,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@test gradtest(rand(5)) do x
y = x.^2
2y + x