943: Fixes #900 r=MikeInnes a=dhairyagandhi96

Thoughts on the test?

cc @MikeInnes

Co-authored-by: Dhairya Gandhi <dhairya@juliacopmuting.com>
This commit is contained in:
bors[bot] 2019-11-26 15:09:27 +00:00 committed by GitHub
commit fb4a48f970
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 1 deletions

View File

@ -6,7 +6,7 @@ using Base: tail
using Zygote, MacroTools, Juno, Reexport, Statistics, Random using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
@reexport using NNlib @reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,

View File

@ -118,6 +118,9 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
) )
end end
# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
@nograd conv_transpose_dims
function (c::ConvTranspose)(x::AbstractArray) function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)

View File

@ -1,5 +1,6 @@
using Flux, Test using Flux, Test
using Flux: maxpool, meanpool using Flux: maxpool, meanpool
using Flux: gradient
@testset "Pooling" begin @testset "Pooling" begin
x = randn(Float32, 10, 10, 3, 2) x = randn(Float32, 10, 10, 3, 2)
@ -54,6 +55,10 @@ end
y = Conv((3,3), 1 => 1)(x) y = Conv((3,3), 1 => 1)(x)
x_hat = ConvTranspose((3, 3), 1 => 1)(y) x_hat = ConvTranspose((3, 3), 1 => 1)(y)
@test size(x_hat) == size(x) @test size(x_hat) == size(x)
m = ConvTranspose((3,3), 1=>1)
# Test that the gradient call does not throw: #900
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
end end
@testset "CrossCor" begin @testset "CrossCor" begin