Merge branch 'master' of github.com:MikeInnes/Flux.jl
This commit is contained in:
commit
7bba38274b
@ -45,8 +45,7 @@ tovec(xs::AbstractArray) = vec(xs)
|
|||||||
tovec(xs) = xs
|
tovec(xs) = xs
|
||||||
|
|
||||||
function back!(x::TrackedArray, Δ)
|
function back!(x::TrackedArray, Δ)
|
||||||
Δ′ = vec(x.Δ)
|
x.Δ .+= Δ
|
||||||
Δ′ .+= tovec(Δ)
|
|
||||||
back!(x.f, Δ)
|
back!(x.f, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -68,9 +68,11 @@ function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
|||||||
TrackedArray(Call(b, args...), b())
|
TrackedArray(Call(b, args...), b())
|
||||||
end
|
end
|
||||||
|
|
||||||
|
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
|
||||||
|
|
||||||
unbroadcast(x, Δ) =
|
unbroadcast(x, Δ) =
|
||||||
size(x) == size(Δ) ? Δ :
|
size(x) == size(Δ) ? Δ :
|
||||||
sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))
|
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
|
||||||
|
|
||||||
function getpartial(Δ, x, i)
|
function getpartial(Δ, x, i)
|
||||||
@inbounds p = getindex(partials(x), i)
|
@inbounds p = getindex(partials(x), i)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
function gradient(f, xs::AbstractArray...)
|
function gradient(f, xs::AbstractArray...)
|
||||||
xs = track.(xs)
|
xs = track.(xs)
|
||||||
back!(f(xs...), [1])
|
back!(f(xs...))
|
||||||
grad.(xs)
|
grad.(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user