Split out `conv_transpose_dims()` so that Zygote can ignore it
This commit is contained in:
parent
c9148194cf
commit
732f97fe16
|
@ -102,20 +102,24 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
|||
|
||||
@treelike ConvTranspose
|
||||
|
||||
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
|
||||
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
|
||||
C_in = size(c.weight)[end-1]
|
||||
batch_size = size(x)[end]
|
||||
# Create DenseConvDims() that looks like the corresponding conv()
|
||||
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
||||
stride=c.stride,
|
||||
padding=c.pad,
|
||||
dilation=c.dilation,
|
||||
)
|
||||
end
|
||||
|
||||
function (c::ConvTranspose)(x::AbstractArray)
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
|
||||
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
|
||||
C_in = size(c.weight)[end-1]
|
||||
batch_size = size(x)[end]
|
||||
# Create DenseConvDims() that looks like the corresponding conv()
|
||||
cdims = DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
||||
stride=c.stride,
|
||||
padding=c.pad,
|
||||
dilation=c.dilation,
|
||||
)
|
||||
cdims = conv_transpose_dims(c, x)
|
||||
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue