diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index a13dab85..76552db4 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -150,46 +150,6 @@ function Base.show(io::IO, l::DepthwiseConv)
   print(io, ")")
 end
 
-"""
-    ConvTranspose(size, in=>out)
-    ConvTranspose(size, in=>out, relu)
-
-Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
-`in` and `out` specify the number of input and output channels respectively.
-Data should be stored in WHCN order. In other words, a 100×100 RGB image would
-be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
-Takes the keyword arguments `pad`, `stride` and `dilation`.
-"""
-struct ConvTranspose{N,F,A,V}
-  σ::F
-  weight::A
-  bias::V
-  stride::NTuple{N,Int}
-  pad::NTuple{N,Int}
-  dilation::NTuple{N,Int}
-end
-
-ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
-              stride = 1, pad = 0, dilation = 1) where {T,N} =
-  ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
-
-ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
-     stride = 1, pad = 0, dilation = 1) where N =
-  ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
-       stride = stride, pad = pad, dilation = dilation)
-
-@treelike ConvTranspose
-
-function (c::ConvTranspose)(x)
-  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
-end
-
-function Base.show(io::IO, l::ConvTranspose)
-  print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
-end
-
 """
     MaxPool(k)
 
diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl
index 21ad6b26..4abd6e03 100644
--- a/src/tracker/lib/array.jl
+++ b/src/tracker/lib/array.jl
@@ -356,12 +356,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 # NNlib
 
 using NNlib
-<<<<<<< HEAD:src/tracker/lib/array.jl
 import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool
-=======
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax,
-  conv, ∇conv_data, depthwiseconv, maxpool, meanpool
->>>>>>> a657c287d0590fdd9e49bb68c35bf96febe45e6d:src/tracker/array.jl
 
 softmax(xs::TrackedArray) = track(softmax, xs)
 