Update as per new AD
This commit is contained in:
parent
0aabf9d86b
commit
2664a16556
|
@ -1,4 +1,4 @@
|
|||
# Additional Convolution Models
|
||||
# Additional Convolution Layers
|
||||
|
||||
## Depthwise Convolutions
|
||||
|
||||
|
|
|
@ -324,19 +324,15 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
_depthwiseconv(x, w, stride, pad) = depthwiseconv(x, w, stride = stride, pad = pad)
|
||||
depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...)
|
||||
|
||||
depthwiseconv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
|
||||
track(_depthwiseconv, x, w, stride, pad)
|
||||
depthwiseconv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
|
||||
track(_depthwiseconv, x, w, stride, pad)
|
||||
depthwiseconv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
|
||||
track(_depthwiseconv, x, w, stride, pad)
|
||||
|
||||
function back(::typeof(_depthwiseconv), Δ, x, w, stride, pad)
|
||||
@back(x, NNlib.∇depthwiseconv_data(Δ, data(x), data(w), stride = stride, pad = pad))
|
||||
@back(w, NNlib.∇depthwiseconv_filter(Δ, data(x), data(w), stride = stride, pad = pad))
|
||||
end
|
||||
@grad depthwiseconv(x, w; kw...) =
|
||||
depthwiseconv(data(x), data(w); kw...),
|
||||
Δ -> nobacksies(:depthwiseconv,
|
||||
(NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
|
|
Loading…
Reference in New Issue