Split out conv_transpose_dims()
so that Zygote can ignore it
This commit is contained in:
parent
c9148194cf
commit
732f97fe16
@ -102,20 +102,24 @@ ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
|||||||
|
|
||||||
@treelike ConvTranspose
|
@treelike ConvTranspose
|
||||||
|
|
||||||
|
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
||||||
|
# Calculate size of "input", from ∇conv_data()'s perspective...
|
||||||
|
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
|
||||||
|
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
|
||||||
|
C_in = size(c.weight)[end-1]
|
||||||
|
batch_size = size(x)[end]
|
||||||
|
# Create DenseConvDims() that looks like the corresponding conv()
|
||||||
|
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
||||||
|
stride=c.stride,
|
||||||
|
padding=c.pad,
|
||||||
|
dilation=c.dilation,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
function (c::ConvTranspose)(x::AbstractArray)
|
function (c::ConvTranspose)(x::AbstractArray)
|
||||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
# Calculate size of "input", from ∇conv_data()'s perspective...
|
cdims = conv_transpose_dims(c, x)
|
||||||
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
|
|
||||||
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
|
|
||||||
C_in = size(c.weight)[end-1]
|
|
||||||
batch_size = size(x)[end]
|
|
||||||
# Create DenseConvDims() that looks like the corresponding conv()
|
|
||||||
cdims = DenseConvDims((I..., C_in, batch_size), size(c.weight);
|
|
||||||
stride=c.stride,
|
|
||||||
padding=c.pad,
|
|
||||||
dilation=c.dilation,
|
|
||||||
)
|
|
||||||
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
|
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user