std derivative
This commit is contained in:
parent
b06884b912
commit
351d3d4771
@ -83,9 +83,9 @@ end
|
||||
Diagonal(in::Integer)
|
||||
|
||||
Creates an element-wise linear transformation layer with learnable
|
||||
vectors α and β:
|
||||
vectors `α` and `β`:
|
||||
|
||||
y = α .* x .+ b
|
||||
y = α .* x .+ β
|
||||
|
||||
The input `x` must be a array where `size(x, 1) == in`.
|
||||
"""
|
||||
|
@ -58,6 +58,12 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
|
||||
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
||||
|
||||
# Hacks to get std working
|
||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
|
||||
|
||||
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
|
||||
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||
|
@ -34,6 +34,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||
|
||||
@test gradtest(rand(5)) do x
|
||||
y = x.^2
|
||||
2y + x
|
||||
|
Loading…
Reference in New Issue
Block a user