1.0 fix for conv transpose

Tejan Karmali 2018-09-08 15:44:06 -04:00
parent d5d9441fc1
commit e86365ed3f
4 changed files with 60 additions and 4 deletions


@@ -5,7 +5,7 @@ module Flux
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
-export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
+export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        Dropout, LayerNorm, BatchNorm,
        params, mapleaves, cpu, gpu


@@ -1,4 +1,4 @@
-using NNlib: conv
+using NNlib: conv, ∇conv_data

 @generated sub2(::Val{N}) where N = :(Val($(N-2)))
@@ -51,6 +51,48 @@ function Base.show(io::IO, l::Conv)
   print(io, ")")
 end

+"""
+    ConvTranspose(size, in=>out)
+    ConvTranspose(size, in=>out, relu)
+
+Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Data should be stored in WHCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+"""
+struct ConvTranspose{N,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
+end
+
+ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+              stride = 1, pad = 0, dilation = 1) where {T,N} =
+  ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
+
+ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
+              stride = 1, pad = 0, dilation = 1) where N =
+  ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
+                stride = stride, pad = pad, dilation = dilation)
+
+@treelike ConvTranspose
+
+function (c::ConvTranspose)(x)
+  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+end
+
+function Base.show(io::IO, l::ConvTranspose)
+  print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
 """
     MaxPool(k)
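The forward pass above reuses `∇conv_data`, the data-gradient of `conv`, which is exactly a transposed convolution, so `ConvTranspose` upsamples its input. A minimal usage sketch, assuming the constructor and keywords exactly as added in this hunk:

```julia
using Flux

# 2×2 transposed convolution from 3 to 8 channels, stride 2.
ct = ConvTranspose((2, 2), 3 => 8, relu, stride = 2)

x = rand(10, 10, 3, 1)   # one 10×10 RGB-like image in WHCN order
y = ct(x)

# With kernel 2, stride 2 and no padding, spatial dims go 10 -> 20:
size(y)                  # (20, 20, 8, 1)
```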


@@ -289,7 +289,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 # NNlib

 using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, maxpool, meanpool

 softmax(xs::TrackedArray) = track(softmax, xs)
@@ -309,6 +309,16 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
     (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
      NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))

+∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
+
+@grad ∇conv_data(x, w; kw...) =
+  ∇conv_data(data(x), data(w); kw...),
+    Δ -> nobacksies(:conv,
+      (NNlib.conv(data.((x, Δ, w))...; kw...),
+       NNlib.∇conv_filter(data.((x, Δ, w))...; kw...)))
+
 maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)

 @grad function maxpool(x, k; kw...)
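These methods make `∇conv_data` itself differentiable, which the new `ConvTranspose` layer needs, since `∇conv_data` is its forward pass. Note the symmetry in the `@grad` rule: the gradient w.r.t. `x` is a plain `conv`, and the gradient w.r.t. `w` is `∇conv_filter`. A rough sketch of what this enables, assuming the Tracker API of this era (`param`, `back!`, `grad`):

```julia
using Flux, NNlib
using NNlib: ∇conv_data

x = param(rand(10, 10, 3, 2))   # tracked input (shaped like a conv output gradient)
w = param(randn(2, 2, 2, 3))    # tracked 2×2 filter, 2 input / 3 output channels
y = ∇conv_data(x, w)            # now a TrackedArray, via the methods above

Flux.Tracker.back!(sum(y))      # backpropagate through the transposed convolution
Flux.Tracker.grad(w)            # filter gradient, computed with NNlib.∇conv_filter
```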


@@ -1,7 +1,7 @@
 using Flux
 using Flux.Tracker, Test, NNlib
 using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
-using NNlib: conv
+using NNlib: conv, ∇conv_data
 using Printf: @sprintf
 using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
 using Statistics: mean, std
@@ -176,6 +176,10 @@ end
 @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
 @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))

+@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))
+
 @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
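These `gradtest` calls compare the tracked gradient of `∇conv_data` against a numerical estimate. A hedged sketch of that style of check for the new methods (the `ngradient` helper below is illustrative, not the test file's actual definition):

```julia
using Flux, NNlib
using Flux.Tracker: back!, grad
using NNlib: ∇conv_data

# Illustrative central-difference gradient, one coordinate at a time.
function ngradient(f, x)
  δ = sqrt(eps())
  g = zero(x)
  for i in eachindex(x)
    tmp = x[i]
    x[i] = tmp - δ/2; y1 = f(x)
    x[i] = tmp + δ/2; y2 = f(x)
    x[i] = tmp
    g[i] = (y2 - y1) / δ
  end
  return g
end

w = randn(2, 2, 3, 2)              # 2×2 kernel, 3 input / 2 output channels
f(x) = sum(∇conv_data(x, w))

x  = rand(10, 10, 2, 1)            # channel count must match w's output dim
xp = param(x)
back!(sum(∇conv_data(xp, w)))      # tracked gradient via the new @grad rule

isapprox(grad(xp), ngradient(f, copy(x)), rtol = 1e-5)   # expected to hold
```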