Update as per new AD
This commit is contained in:
parent
0aabf9d86b
commit
2664a16556
@ -1,4 +1,4 @@
|
|||||||
# Additional Convolution Models
|
# Additional Convolution Layers
|
||||||
|
|
||||||
## Depthwise Convolutions
|
## Depthwise Convolutions
|
||||||
|
|
||||||
|
@ -324,19 +324,15 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||||||
|
|
||||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||||
|
|
||||||
_depthwiseconv(x, w, stride, pad) = depthwiseconv(x, w, stride = stride, pad = pad)
|
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||||
|
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||||
|
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||||
|
|
||||||
depthwiseconv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
|
@grad depthwiseconv(x, w; kw...) =
|
||||||
track(_depthwiseconv, x, w, stride, pad)
|
depthwiseconv(data(x), data(w); kw...),
|
||||||
depthwiseconv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
|
Δ -> nobacksies(:depthwiseconv,
|
||||||
track(_depthwiseconv, x, w, stride, pad)
|
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
|
||||||
depthwiseconv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
|
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
|
||||||
track(_depthwiseconv, x, w, stride, pad)
|
|
||||||
|
|
||||||
function back(::typeof(_depthwiseconv), Δ, x, w, stride, pad)
|
|
||||||
@back(x, NNlib.∇depthwiseconv_data(Δ, data(x), data(w), stride = stride, pad = pad))
|
|
||||||
@back(w, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
|
|
||||||
end
|
|
||||||
|
|
||||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||||
|
Loading…
Reference in New Issue
Block a user